OpenDAS / ktransformers / Commits / 3f14e311

Commit 3f14e311 (unverified), authored May 07, 2025 by Yaochen Han, committed by GitHub on May 07, 2025

Merge pull request #1247 from aubreyli/_get_logits_warper

ktransformers/utils: fix _get_logits_warper error

Parents: 7530491f, b3a1fcf4
Changes: 1
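For context, here is a minimal, hedged sketch of the failure this merge addresses, assuming the project is run against a transformers release that has dropped the private GenerationMixin._get_logits_warper helper. The model name "gpt2" and the sampling values are placeholders, not taken from the commit.

# Hedged sketch: reproduce the _get_logits_warper error the commit fixes.
# "gpt2" and the config values below are illustrative placeholders.
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")
generation_config = GenerationConfig(do_sample=True, temperature=0.6, top_p=0.95)

if hasattr(model, "_get_logits_warper"):
    try:
        # transformers around 4.43 expected an explicit device argument
        logits_warper = model._get_logits_warper(generation_config, device="cpu")
    except TypeError:
        # earlier releases took only the generation config
        logits_warper = model._get_logits_warper(generation_config)
else:
    # newer releases removed the helper entirely, so the old try/except in
    # prefill_and_generate failed; this commit inlines an equivalent builder,
    # tf_logits_warper(generation_config), instead.
    print("GenerationMixin no longer provides _get_logits_warper")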
Showing 1 changed file with 64 additions and 8 deletions:
ktransformers/util/utils.py (+64, -8)
ktransformers/util/utils.py (view file @ 3f14e311)
@@ -11,6 +11,17 @@ from torch import nn
 import itertools
 import time
 import enum
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
+)
 from ktransformers.util.custom_gguf import translate_name_to_gguf
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.operators import base_operator
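As an aside, a small hedged illustration of how the newly imported warper classes compose at runtime: a LogitsProcessorList applies each warper in order to an (input_ids, scores) pair, masking filtered tokens to -inf. The tensor shapes and sampling values below are placeholders, not values from this repository.

# Hedged sketch: composing the imported warpers (placeholder data).
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),                       # rescale logits first
    TopKLogitsWarper(top_k=50, min_tokens_to_keep=1),   # keep the 50 best tokens
    TopPLogitsWarper(top_p=0.9, min_tokens_to_keep=1),  # then nucleus filtering
])

input_ids = torch.tensor([[1, 2, 3]])      # dummy prompt ids (illustrative)
scores = torch.randn(1, 32000)             # dummy vocab-sized logits (illustrative)
filtered = warpers(input_ids, scores)      # filtered-out tokens become -inf
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)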
@@ -126,6 +137,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
+def tf_logits_warper(generation_config):
+    """
+    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+    used for multinomial sampling.
+    """
+
+    # instantiate warpers list
+    warpers = LogitsProcessorList()
+
+    # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+    # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+    if generation_config.num_beams > 1:
+        if isinstance(generation_config._eos_token_tensor, list):
+            min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+        elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+            min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+        else:
+            min_tokens_to_keep = 2
+    else:
+        min_tokens_to_keep = 1
+
+    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+    # all samplers can be found in `generation_utils_samplers.py`
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.min_p is not None:
+        # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+        warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+        warpers.append(
+            TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+        warpers.append(
+            EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+        warpers.append(
+            EtaLogitsWarper(
+                epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+            )
+        )
+    # `LogitNormalization` should always be the last logit processor, when present
+    if generation_config.renormalize_logits is True:
+        warpers.append(LogitNormalization())
+    return warpers
+
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                          mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
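A hedged usage sketch of the builder added above; the sampling values are illustrative, not taken from the commit. With num_beams left at 1, min_tokens_to_keep stays 1 and generation_config._eos_token_tensor is never consulted. Note, purely as an observation on the diff, that the eta_cutoff and renormalize_logits branches reference device and LogitNormalization, neither of which this hunk defines or imports, so those paths are only safe for configs that leave the corresponding options unset.

# Hedged sketch: calling the builder added above (placeholder values).
from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True, temperature=0.6, top_k=20, top_p=0.95, num_beams=1,
)
warpers = tf_logits_warper(generation_config)
# With this config the list holds, in order:
#   TemperatureLogitsWarper(0.6), TopKLogitsWarper(top_k=20, ...), TopPLogitsWarper(top_p=0.95, ...)
# min_p, typical_p, epsilon_cutoff, and eta_cutoff keep their defaults, so those branches are skipped.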
@@ -201,14 +263,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         # change this to modify generate config
         #top_k=5, top_p=0.85, temperature=0.1
     )
-    try:    # transformers==4.43
-        logits_warper = (
-            model._get_logits_warper(generation_config, device=inputs.device)
-        )
-    except:
-        logits_warper = (
-            model._get_logits_warper(generation_config)
-        )
+    logits_warper = tf_logits_warper(generation_config)
     cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
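Finally, a hedged sketch (not the project's actual decode loop) of how a logits_warper built this way is typically consumed per decoding step, alongside bookkeeping like the cache_position and generated_ids tensors shown in the context above. The helper name and tensor shapes are illustrative assumptions.

# Hedged sketch: per-step sampling with the warper list (illustrative only).
import torch

def sample_next_token(input_ids, last_logits, logits_warper, do_sample=True):
    # last_logits: [batch, vocab_size] scores for the most recent position
    filtered = logits_warper(input_ids, last_logits)      # temperature/top-k/top-p, etc.
    if do_sample:
        probs = torch.softmax(filtered, dim=-1)
        return torch.multinomial(probs, num_samples=1)    # stochastic choice
    return torch.argmax(filtered, dim=-1, keepdim=True)   # greedy fallback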