ox696c / ktransformers

Commit b3a1fcf4, authored May 01, 2025 by Aubrey Li

    ktransformers/utils: fix _get_logits_warper error

parent 7530491f
Showing 1 changed file with 64 additions and 8 deletions.

ktransformers/util/utils.py  (+64 −8)
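The terse commit message leaves the motivation implicit: `GenerationMixin._get_logits_warper` was a private transformers helper that ktransformers called to build its sampling warpers. It gained a required `device` argument in transformers 4.43 (hence the old try/except below) and was removed entirely in later releases (around v4.46, when the warpers were folded into `_get_logits_processor`), so on a current install both branches of the old code can fail. A minimal sketch of the failure mode, assuming a recent transformers release; the model name is purely illustrative:

# Sketch of the error this commit fixes, assuming transformers >= 4.46
# (where GenerationMixin no longer defines _get_logits_warper).
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM; "gpt2" is illustrative
generation_config = GenerationConfig(do_sample=True, temperature=0.6, top_p=0.95)

try:
    warpers = model._get_logits_warper(generation_config)  # old private API
except AttributeError as e:
    print(e)  # e.g. 'GPT2LMHeadModel' object has no attribute '_get_logits_warper'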
@@ -11,6 +11,17 @@ from torch import nn
 import itertools
 import time
 import enum
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
+)
 from ktransformers.util.custom_gguf import translate_name_to_gguf
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.operators import base_operator
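Each warper imported above implements one sampling transform: temperature scaling, top-k, nucleus (top-p), min-p, typical decoding, and the epsilon/eta cutoffs. They compose left to right inside a `LogitsProcessorList`. A toy check of two of them chaining; the values are arbitrary:

# Toy illustration of two warpers chaining; values are arbitrary.
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.5),                     # divides logits by 0.5 (sharpens)
    TopKLogitsWarper(top_k=2, min_tokens_to_keep=1),  # masks all but the 2 best tokens
])
input_ids = torch.zeros(1, 1, dtype=torch.long)       # dummy prompt ids
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # 4-token toy vocabulary
print(warpers(input_ids, logits))
# tensor([[4., 2., -inf, -inf]])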
@@ -126,6 +137,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
+def tf_logits_warper(generation_config):
+    """
+    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+    used for multinomial sampling.
+    """
+
+    # instantiate warpers list
+    warpers = LogitsProcessorList()
+
+    # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+    # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+    if generation_config.num_beams > 1:
+        if isinstance(generation_config._eos_token_tensor, list):
+            min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+        elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+            min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+        else:
+            min_tokens_to_keep = 2
+    else:
+        min_tokens_to_keep = 1
+
+    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+    # all samplers can be found in `generation_utils_samplers.py`
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.min_p is not None:
+        # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+        warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+        warpers.append(
+            TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+        warpers.append(
+            EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+        warpers.append(
+            EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device)
+        )
+    # `LogitNormalization` should always be the last logit processor, when present
+    if generation_config.renormalize_logits is True:
+        warpers.append(LogitNormalization())
+    return warpers
+
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
@@ -201,14 +263,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         # change this to modify generate config
         #top_k=5, top_p=0.85, temperature=0.1
     )
-    try:  # transformers==4.43
-        logits_warper = (
-            model._get_logits_warper(generation_config, device=inputs.device)
-        )
-    except:
-        logits_warper = (
-            model._get_logits_warper(generation_config)
-        )
+    logits_warper = tf_logits_warper(generation_config)
+
     cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
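With the helper in place, `prefill_and_generate` builds the warper list itself instead of relying on the removed private API. Two caveats are visible in the diff as committed: the eta-cutoff branch of `tf_logits_warper` references a `device` variable and the renormalize branch references `LogitNormalization`, and neither name appears among the added imports, so those two options only work if the names happen to be in scope at module level. A minimal usage sketch that sticks to the options exercised by the default config; the config values and `vocab_size` are illustrative stand-ins:

# Minimal sampling-loop sketch using tf_logits_warper; config values are
# illustrative and vocab_size is a made-up stand-in for the real vocabulary.
import torch
from transformers import GenerationConfig

generation_config = GenerationConfig(do_sample=True, temperature=0.6, top_k=50, top_p=0.95)
logits_warper = tf_logits_warper(generation_config)   # num_beams defaults to 1

vocab_size = 32000
input_ids = torch.zeros(1, 8, dtype=torch.long)       # dummy prompt
next_token_logits = torch.randn(1, vocab_size)        # stand-in for a model forward pass

scores = logits_warper(input_ids, next_token_logits)  # temperature -> top-k -> top-p
probs = torch.nn.functional.softmax(scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # multinomial sampling, as in generate()

This mirrors what `GenerationMixin._sample` does internally, so for the options it exercises the replacement should keep the decode path behaviorally equivalent while dropping the dependency on a private, since-removed helper.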