Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
907251c7
Commit
907251c7
authored
Feb 04, 2025
by
Azure
Browse files
done support deepseekv3
parent
f748cd29
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1375 additions
and
542 deletions
+1375
-542
ktransformers/models/modeling_deepseek_v3.py
ktransformers/models/modeling_deepseek_v3.py
+1290
-515
ktransformers/operators/RoPE.py
ktransformers/operators/RoPE.py
+52
-1
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+2
-2
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+3
-3
ktransformers/operators/gate.py
ktransformers/operators/gate.py
+2
-3
ktransformers/operators/models.py
ktransformers/operators/models.py
+15
-9
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
...s/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
+6
-6
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+2
-2
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+3
-1
No files found.
ktransformers/models/modeling_deepseek_v3.py
View file @
907251c7
This diff is collapsed.
Click to expand it.
ktransformers/operators/RoPE.py
View file @
907251c7
...
@@ -23,7 +23,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
...
@@ -23,7 +23,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.utils
import
InferenceState
from
ktransformers.util.utils
import
InferenceState
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
import
torch
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class
RotaryEmbedding
(
BaseInjectedModule
,
DeepseekV2RotaryEmbedding
):
class
RotaryEmbedding
(
BaseInjectedModule
,
DeepseekV2RotaryEmbedding
):
...
@@ -56,6 +56,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
...
@@ -56,6 +56,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
)
)
class RotaryEmbeddingV3(BaseInjectedModule):
    """Injected rotary position embedding that computes cos/sin on the fly.

    The frequency table (``inv_freq``) is not created in ``__init__``; it is
    built by :meth:`load` from the model config, so ``load`` must run before
    :meth:`forward` is called.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        """Initialise the injected module and remember both target devices.

        Args:
            key: Forwarded unchanged to ``BaseInjectedModule.__init__``.
            gguf_loader: Forwarded unchanged to ``BaseInjectedModule.__init__``.
            config: Model configuration; ``load`` reads ``qk_rope_head_dim``,
                ``max_position_embeddings`` and ``rope_theta`` from it.
            orig_module: The module this operator replaces.
            generate_device: Device name stored on the instance and forwarded
                to the base initialiser.
            prefill_device: Device name stored on the instance (it is not
                forwarded to the base initialiser).
            **kwargs: Extra options forwarded to the base initialiser.
        """
        BaseInjectedModule.__init__(
            self,
            key,
            gguf_loader,
            config,
            orig_module,
            generate_device,
            **kwargs,
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return the ``(cos, sin)`` rotary tables for ``position_ids``.

        Args:
            x: Tensor used only for its device type and dtype. The original
                comment gives its layout as
                ``[bs, num_attention_heads, seq_len, head_size]``.
            position_ids: Position indices; indexed as ``[:, None, :]``, so a
                2-D ``(batch, seq_len)`` tensor is presumably expected.

        Returns:
            Tuple ``(cos, sin)``, both cast to ``x.dtype``. Each is the
            per-position angle table concatenated with itself along the last
            dimension.
        """
        inv_freq_column = self.inv_freq[None, :, None].float().expand(
            position_ids.shape[0], -1, 1
        )
        positions_row = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts.
        # See https://github.com/huggingface/transformers/pull/29285
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = torch.matmul(
                inv_freq_column.float(), positions_row.float()
            ).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos_table = table.cos()
            sin_table = table.sin()

        target_dtype = x.dtype
        return cos_table.to(dtype=target_dtype), sin_table.to(dtype=target_dtype)

    def load(self):
        """Build the rotary state from ``self.config`` on ``self.device``."""
        cfg = self.config
        self._init(
            dim=cfg.qk_rope_head_dim,
            max_position_embeddings=cfg.max_position_embeddings,
            base=cfg.rope_theta,
            device=self.device,
        )

    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
        """Store the rotary hyper-parameters and compute ``inv_freq``.

        Args:
            dim: Rotary dimension; one frequency is produced per even index
                in ``range(0, dim, 2)``.
            max_position_embeddings: Stored as both
                ``max_position_embeddings`` and ``max_seq_len_cached``.
            base: Base of the geometric frequency progression.
            device: Device on which the exponent tensor is placed.
            scaling_factor: Stored on the instance; not otherwise used by the
                code in this class.
        """
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq is kept as a plain attribute here, not a registered buffer.
        even_indices = torch.arange(0, self.dim, 2, dtype=torch.int64)
        exponents = even_indices.float().to(device) / self.dim
        self.inv_freq = 1.0 / (self.base ** exponents)

        # Retained for backward compatibility (the original notes "For BC").
        self.max_seq_len_cached = max_position_embeddings
class
RotaryEmbeddingV2
(
BaseInjectedModule
,
LlamaRotaryEmbedding
):
class
RotaryEmbeddingV2
(
BaseInjectedModule
,
LlamaRotaryEmbedding
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
ktransformers/operators/attention.py
View file @
907251c7
...
@@ -151,7 +151,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -151,7 +151,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
return
attn_output
,
attn_weights
return
attn_output
,
attn_weights
,
past_key_value
def
forward
(
def
forward
(
self
,
self
,
...
@@ -220,7 +220,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
...
@@ -220,7 +220,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
attn_output
=
torch
.
cat
((
attn_output
,
cur_output
),
dim
=-
2
)
attn_output
=
torch
.
cat
((
attn_output
,
cur_output
),
dim
=-
2
)
attn_weight
=
torch
.
cat
((
attn_weight
,
cur_attn_weight
),
dim
=-
2
)
attn_weight
=
torch
.
cat
((
attn_weight
,
cur_attn_weight
),
dim
=-
2
)
return
attn_output
,
attn_weight
return
attn_output
,
attn_weight
,
past_key_value
class
KDeepseekV2Attention
(
BaseInjectedModule
,
DeepseekV2Attention
):
class
KDeepseekV2Attention
(
BaseInjectedModule
,
DeepseekV2Attention
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
...
...
ktransformers/operators/experts.py
View file @
907251c7
...
@@ -734,7 +734,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
...
@@ -734,7 +734,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
identity
=
hidden_states
identity
=
hidden_states
orig_shape
=
hidden_states
.
shape
orig_shape
=
hidden_states
.
shape
sequence_length
=
orig_shape
[
1
]
sequence_length
=
orig_shape
[
1
]
topk_idx
,
topk_weight
,
router_logits
=
self
.
gate
(
hidden_states
)
topk_idx
,
topk_weight
=
self
.
gate
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
# only for generate phase
# only for generate phase
...
@@ -745,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
...
@@ -745,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y
=
self
.
experts
.
generate_experts
.
sync_for_one_decode
().
unsqueeze
(
0
)
y
=
self
.
experts
.
generate_experts
.
sync_for_one_decode
().
unsqueeze
(
0
)
y
+=
y_
y
+=
y_
y
.
resize_
(
*
orig_shape
)
y
.
resize_
(
*
orig_shape
)
return
y
,
router_logits
return
y
if
self
.
config
.
n_shared_experts
is
not
None
:
if
self
.
config
.
n_shared_experts
is
not
None
:
y_
=
self
.
shared_experts
(
identity
).
squeeze
(
0
)
y_
=
self
.
shared_experts
(
identity
).
squeeze
(
0
)
...
@@ -768,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
...
@@ -768,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
)
)
if
self
.
config
.
n_shared_experts
is
not
None
:
if
self
.
config
.
n_shared_experts
is
not
None
:
y
+=
y_
y
+=
y_
return
y
,
router_logits
return
y
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
moe_on_cpuinfer
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
moe_on_cpuinfer
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
ktransformers/operators/gate.py
View file @
907251c7
...
@@ -16,9 +16,6 @@ from cpuinfer_ext.moe import MOEConfig, MOE
...
@@ -16,9 +16,6 @@ from cpuinfer_ext.moe import MOEConfig, MOE
import
ctypes
import
ctypes
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3TopkRouter
from
ktransformers.util.utils
import
InferenceState
from
ktransformers.server.config.config
import
Config
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
...
@@ -102,6 +99,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
...
@@ -102,6 +99,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
):
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
KMoEGateBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
KMoEGateBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
return
self
.
orig_module
.
forward
(
hidden_states
)
return
self
.
orig_module
.
forward
(
hidden_states
)
...
...
ktransformers/operators/models.py
View file @
907251c7
...
@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
...
@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
if
use_legacy_cache
:
if
use_legacy_cache
:
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values_length
=
past_key_values
.
get_usable_length
(
seq_length
)
past_key_values_length
=
past_key_values
.
get_usable_length
(
seq_length
)
if
inputs_embeds
is
None
:
org_device
=
input_ids
.
device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
).
to
(
org_device
)
input_ids
=
input_ids
.
to
(
org_device
)
if
cache_position
is
None
:
if
cache_position
is
None
:
past_seen_tokens
=
(
past_seen_tokens
=
(
...
@@ -639,13 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
...
@@ -639,13 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
cache_position
.
unsqueeze
(
0
)
position_ids
=
cache_position
.
unsqueeze
(
0
)
if
inputs_embeds
is
None
:
org_device
=
input_ids
.
device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
).
to
(
org_device
)
input_ids
=
input_ids
.
to
(
org_device
)
if
per_layer_prefill_flag
:
if
per_layer_prefill_flag
:
causal_mask
=
None
causal_mask
=
None
else
:
else
:
...
@@ -717,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
...
@@ -717,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
self
.
load_layer_to
(
decoder_layer
,
InferenceState
.
PREFILL
)
self
.
load_layer_to
(
decoder_layer
,
InferenceState
.
PREFILL
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
t4
=
time
.
time
()
t4
=
time
.
time
()
# with open("log.txt", "a") as f:
# f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
layer_outputs
=
decoder_layer
(
layer_outputs
=
decoder_layer
(
hidden_states
,
hidden_states
,
attention_mask
=
causal_mask
,
attention_mask
=
causal_mask
,
...
@@ -739,13 +741,17 @@ class KDeepseekV2Model(BaseInjectedModule):
...
@@ -739,13 +741,17 @@ class KDeepseekV2Model(BaseInjectedModule):
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
layer_outputs
[
0
]
# @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
# @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
#
if use_cache:
if
use_cache
:
#
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
next_decoder_cache
=
layer_outputs
[
2
if
output_attentions
else
1
]
if
output_attentions
:
if
output_attentions
:
all_self_attns
+=
(
layer_outputs
[
1
],)
all_self_attns
+=
(
layer_outputs
[
1
],)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
# with open("log.txt", "a") as f:
# f.write(f"@@@After layers\n")
# f.write(f"hidden_states={hidden_states}\n")
# f.write(f"hidden_states.shape={hidden_states.shape}\n")
if
per_layer_prefill_flag
:
if
per_layer_prefill_flag
:
t6
=
time
.
time
()
t6
=
time
.
time
()
...
...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
View file @
907251c7
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
."
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
."
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace
:
replace
:
class
:
ktransformers.operators.RoPE.
DeepSeekV3Yarn
RotaryEmbedding
class
:
ktransformers.operators.RoPE.RotaryEmbedding
V3
kwargs
:
kwargs
:
generate_device
:
"
cuda:0"
generate_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
."
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
."
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
class
:
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace
:
replace
:
class
:
ktransformers.operators.RoPE.
DeepSeekV3Yarn
RotaryEmbedding
class
:
ktransformers.operators.RoPE.RotaryEmbedding
V3
kwargs
:
kwargs
:
generate_device
:
"
cuda:1"
generate_device
:
"
cuda:1"
prefill_device
:
"
cuda:1"
prefill_device
:
"
cuda:1"
...
@@ -64,7 +64,7 @@
...
@@ -64,7 +64,7 @@
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
.mlp
\\
.gate$"
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
.mlp
\\
.gate$"
class
:
ktransformers.models.modeling_deepseek_v3.
DeepseekV3TopkRou
te
r
class
:
ktransformers.models.modeling_deepseek_v3.
MoEGa
te
replace
:
replace
:
class
:
ktransformers.operators.gate.KMoEGate
class
:
ktransformers.operators.gate.KMoEGate
kwargs
:
kwargs
:
...
@@ -72,7 +72,7 @@
...
@@ -72,7 +72,7 @@
prefill_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
.mlp
\\
.gate$"
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
.mlp
\\
.gate$"
class
:
ktransformers.models.modeling_deepseek_v3.
DeepseekV3TopkRou
te
r
class
:
ktransformers.models.modeling_deepseek_v3.
MoEGa
te
replace
:
replace
:
class
:
ktransformers.operators.gate.KMoEGate
# mlp module with custom forward function
class
:
ktransformers.operators.gate.KMoEGate
# mlp module with custom forward function
kwargs
:
kwargs
:
...
@@ -106,14 +106,14 @@
...
@@ -106,14 +106,14 @@
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
.(0|[1-9]|[12][0-9])
\\
.self_attn$"
replace
:
replace
:
class
:
ktransformers.operators.attention.KDeepseekV
3
Attention
# optimized MLA implementation
class
:
ktransformers.operators.attention.KDeepseekV
2
Attention
# optimized MLA implementation
kwargs
:
kwargs
:
generate_device
:
"
cuda:0"
generate_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
prefill_device
:
"
cuda:0"
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
.([3456][0-9])
\\
.self_attn$"
replace
:
replace
:
class
:
ktransformers.operators.attention.KDeepseekV
3
Attention
# optimized MLA implementation
class
:
ktransformers.operators.attention.KDeepseekV
2
Attention
# optimized MLA implementation
kwargs
:
kwargs
:
generate_device
:
"
cuda:1"
generate_device
:
"
cuda:1"
prefill_device
:
"
cuda:1"
prefill_device
:
"
cuda:1"
...
...
ktransformers/server/backend/interfaces/ktransformers.py
View file @
907251c7
...
@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
self
.
args
=
args
self
.
args
=
args
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_grad_enabled
(
False
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
True
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
config
.
_attn_implementation
=
"flash_attention_2"
config
.
_attn_implementation
=
"flash_attention_2"
...
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
if
self
.
use_static_cache
:
if
self
.
use_static_cache
:
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
torch_device
)
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
torch_device
)
logits
=
self
.
model
(
logits
=
self
.
model
(
self
.
current_ids
,
self
.
current_ids
.
to
(
torch_device
)
,
cache_position
=
self
.
active_cache_position
,
cache_position
=
self
.
active_cache_position
,
past_key_values
=
self
.
cache
,
past_key_values
=
self
.
cache
,
attention_mask
=
mask
,
attention_mask
=
mask
,
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
907251c7
...
@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
return
self
.
streamer
.
put
(
new_tokens
)
return
self
.
streamer
.
put
(
new_tokens
)
def
logits_to_token
(
self
,
logits
:
torch
.
Tensor
):
def
logits_to_token
(
self
,
logits
:
torch
.
Tensor
):
logits
=
logits
/
self
.
args
.
temperature
logits
=
logits
/
self
.
args
.
temperature
if
self
.
args
.
temperature
!=
0
else
logits
for
token_idx
in
self
.
ever_generated_ids
:
for
token_idx
in
self
.
ever_generated_ids
:
if
logits
[
token_idx
]
<
0
:
if
logits
[
token_idx
]
<
0
:
...
@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
if
isinstance
(
local_messages
,
List
):
if
isinstance
(
local_messages
,
List
):
input_ids
=
self
.
format_and_tokenize_input_ids
(
thread_id
,
local_messages
)
input_ids
=
self
.
format_and_tokenize_input_ids
(
thread_id
,
local_messages
)
elif
isinstance
(
local_messages
,
str
):
elif
isinstance
(
local_messages
,
str
):
#local_messages = local_messages[0]['content']
input_ids
=
self
.
tokenize_prompt
(
local_messages
)
input_ids
=
self
.
tokenize_prompt
(
local_messages
)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else
:
else
:
raise
ValueError
(
"local_messages should be List or str"
)
raise
ValueError
(
"local_messages should be List or str"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment