OpenDAS / ktransformers · commit f873558a
Authored Feb 01, 2025 by Azure
Parent: 5a50b346

    update rope calculation; update modeling.py; update gate for moe

Showing 11 changed files with 406 additions and 416 deletions (+406 −416).
ktransformers/configs/config.yaml                                        +1    -1
ktransformers/local_chat.py                                              +2    -2
ktransformers/models/configuration_deepseek_v3.py                        +51   -47
ktransformers/models/custom_cache.py                                     +4    -0
ktransformers/models/modeling_deepseek_v3.py                             +230  -290
ktransformers/operators/attention.py                                     +3    -2
ktransformers/operators/experts.py                                       +5    -4
ktransformers/operators/gate.py                                          +3    -4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml    +4    -4
ktransformers/server/config/config.py                                    +1    -1
ktransformers/util/modeling_rope_utils.py                                +102  -61
ktransformers/configs/config.yaml
@@ -54,4 +54,4 @@ long_context:
   token_step:
 local_chat:
-  prompt_file: "./ktransformers/p.txt"
+  prompt_file: ""
\ No newline at end of file
ktransformers/local_chat.py
@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
@@ -78,7 +78,7 @@ def local_chat():
         else:
             content += line + "\n"
     if content == "":
-        if config.prompt_file == None or config.prompt_file == "":
+        if not config.prompt_file:
             content = "hi"
         else:
             content = open(config.prompt_file, "r").read()
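Note: the rewritten guard relies on Python truthiness, since `not config.prompt_file` is true for both `None` and the empty string. A quick standalone sketch (not part of the commit) confirming the two forms agree:

    # The comparison on the left mirrors the removed line; the right mirrors the new one.
    for prompt_file in (None, "", "./some/prompt.txt"):
        assert (prompt_file == None or prompt_file == "") == (not prompt_file)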
ktransformers/models/configuration_deepseekv3.py → ktransformers/models/configuration_deepseek_v3.py
@@ -14,19 +14,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-DeepSeekV3 model configuration
-"""
+"""DeepSeekV3 model configuration"""
 from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
 
 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the DeepSeek-V3.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 129280):
             Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
             Dimension of the MoE representations.
         num_hidden_layers (`int`, *optional*, defaults to 61):
             Number of hidden layers in the Transformer decoder.
-        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
-            Number of nextn predict layers in the DeepSeekV3 Model.
         num_attention_heads (`int`, *optional*, defaults to 128):
             Number of attention heads for each attention layer in the Transformer decoder.
         num_key_value_heads (`int`, *optional*, defaults to 128):
@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `num_attention_heads`.
         n_shared_experts (`int`, *optional*, defaults to 1):
-            Number of shared experts, None means dense model.
+            Number of shared experts.
         n_routed_experts (`int`, *optional*, defaults to 256):
-            Number of routed experts, None means dense model.
-        ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
+            Number of routed experts.
         routed_scaling_factor (`float`, *optional*, defaults to 2.5):
             Scaling factor or routed experts.
-        kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
-        q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
-        qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
-        v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
-            Topk method used in routed gate.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
         n_group (`int`, *optional*, defaults to 8):
             Number of groups for routed experts.
         topk_group (`int`, *optional*, defaults to 4):
             Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
         num_experts_per_tok (`int`, *optional*, defaults to 8):
             Number of selected experts, None means dense model.
-        moe_layer_freq (`int`, *optional*, defaults to 1):
-            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
         first_k_dense_replace (`int`, *optional*, defaults to 3):
             Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                             \--k dense layers--/
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the weights of the routed experts.
-        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
-            Method of computing expert weights.
         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
             Auxiliary loss weight coefficient.
             Whether to compute the auxiliary loss for each individual sample.
-        seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
 
     ```python
     >>> from transformers import DeepseekV3Model, DeepseekV3Config
     >>> # Initializing a Deepseek-V3 style configuration
     >>> configuration = DeepseekV3Config()
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
     model_type = "deepseek_v3"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `DeepseekV3Model`
+    base_model_tp_plan = {
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
+    }
 
     def __init__(
         self,
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
         moe_intermediate_size=2048,
         num_hidden_layers=61,
-        num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
         n_shared_experts=1,
         n_routed_experts=256,
-        ep_size=1,
         routed_scaling_factor=2.5,
         kv_lora_rank=512,
         q_lora_rank=1536,
         qk_rope_head_dim=64,
         v_head_dim=128,
         qk_nope_head_dim=128,
-        topk_method='noaux_tc',
         n_group=8,
         topk_group=4,
         num_experts_per_tok=8,
-        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
-        scoring_func='sigmoid',
         aux_loss_alpha=0.001,
-        seq_aux=True,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.moe_intermediate_size = moe_intermediate_size
         self.num_hidden_layers = num_hidden_layers
-        self.num_nextn_predict_layers = num_nextn_predict_layers
         self.num_attention_heads = num_attention_heads
         self.n_shared_experts = n_shared_experts
         self.n_routed_experts = n_routed_experts
-        self.ep_size = ep_size
         self.routed_scaling_factor = routed_scaling_factor
         self.kv_lora_rank = kv_lora_rank
         self.q_lora_rank = q_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.qk_nope_head_dim = qk_nope_head_dim
-        self.topk_method = topk_method
-        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.head_dim = qk_rope_head_dim
         self.n_group = n_group
         self.topk_group = topk_group
         self.num_experts_per_tok = num_experts_per_tok
-        self.moe_layer_freq = moe_layer_freq
         self.first_k_dense_replace = first_k_dense_replace
         self.norm_topk_prob = norm_topk_prob
-        self.scoring_func = scoring_func
         self.aux_loss_alpha = aux_loss_alpha
-        self.seq_aux = seq_aux
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
+
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
 
         super().__init__(
             pad_token_id=pad_token_id,
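Note: the tail of the new `__init__` copies a legacy `type` key into `rope_type` before calling `rope_config_validation`. A minimal sketch of that back-compat shim on a plain dict (illustrative values, not the real config object):

    # Old checkpoints only carry "type"; validation keys off "rope_type".
    rope_scaling = {"type": "yarn", "factor": 40.0, "original_max_position_embeddings": 4096}
    if rope_scaling is not None and "type" in rope_scaling:
        rope_scaling["rope_type"] = rope_scaling["type"]
    assert rope_scaling["rope_type"] == "yarn"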
ktransformers/models/custom_cache.py
@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
+
+    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
+        """Returns the maximum shape of the cache."""
+        return self.max_cache_len
\ No newline at end of file
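Note: the added helper is annotated as returning a four-int tuple but simply reports `max_cache_len`. A stand-in sketch of the behavior (hypothetical class, not the ktransformers cache):

    class _CacheSketch:                      # stand-in, not the real StaticCache
        def __init__(self, max_cache_len: int):
            self.max_cache_len = max_cache_len
        def get_max_cache_shape(self) -> int:
            return self.max_cache_len

    assert _CacheSketch(8192).get_max_cache_shape() == 8192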
ktransformers/models/modeling_deepseekv3.py → ktransformers/models/modeling_deepseek_v3.py
(diff collapsed in the web view; +230 additions, −290 deletions not shown)
ktransformers/operators/attention.py
@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
+from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
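Note: the alias keeps the DeepSeek-V2 and V3 rotary helpers from colliding in one namespace. For orientation, a generic rotate-half RoPE application with illustrative shapes; the actual V3 helper in modeling_deepseek_v3.py may differ in detail:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q_pe, k_pe, cos, sin):
        # Standard rotate-half RoPE applied to the positional part of Q/K.
        return q_pe * cos + rotate_half(q_pe) * sin, k_pe * cos + rotate_half(k_pe) * sin

    q_pe = torch.randn(1, 128, 4, 64)   # (batch, heads, seq, qk_rope_head_dim), illustrative shapes
    k_pe = torch.randn(1, 1, 4, 64)
    cos, sin = torch.randn(4, 64), torch.randn(4, 64)
    q_rot, k_rot = apply_rope(q_pe, k_pe, cos, sin)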
ktransformers/operators/experts.py
@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
         from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-        from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+        from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
         from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
         from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # only for generate phase
         if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y
+            return y, router_logits
 
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         )
         if self.config.n_shared_experts is not None:
             y += y_
-        return y
+        return y, router_logits
 
     @torch.no_grad()
     def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
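Note: the gate is now expected to return router logits as a third value, and the MoE forward passes them through. A self-contained sketch of that call shape with a hypothetical router (not the ktransformers classes):

    import torch

    def gate(hidden_states, n_experts=8, top_k=2):
        # Hypothetical router: returns indices, weights, and the raw logits (the new third value).
        router_logits = torch.randn(hidden_states.shape[0], n_experts)
        topk_weight, topk_idx = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
        return topk_idx, topk_weight, router_logits

    x = torch.randn(4, 16)                              # (tokens, hidden)
    topk_idx, topk_weight, router_logits = gate(x)      # forward() now returns logits alongside y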
ktransformers/operators/gate.py
@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
 import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseekv3 import MoEGate
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         else:
             raise ValueError("Invalid weight type")
         self.orig_module.weight = self.orig_module.weight.to(device)
-        if self.topk_method == "noaux_tc":
-            self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
 
     def unload(self):
         if self.weight is not None:
             self.weight = None
-        if self.topk_method == "noaux_tc":
+        if self.e_score_correction_bias is not None:
             self.e_score_correction_bias = None
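Note: the operator now moves `e_score_correction_bias` unconditionally, since the new `DeepseekV3TopkRouter` always carries it. As background, a simplified sketch of how such a correction bias typically enters "noaux_tc"-style selection (the bias steers selection, weights come from the unbiased scores; the real router also applies group-limited top-k and normalization):

    import torch

    def select_experts(logits, e_score_correction_bias, top_k=8):
        scores = logits.sigmoid()                      # sigmoid scoring
        biased = scores + e_score_correction_bias      # bias affects which experts are picked ...
        topk_idx = biased.topk(top_k, dim=-1).indices
        topk_weight = scores.gather(-1, topk_idx)      # ... but weights use the unbiased scores
        return topk_idx, topk_weight

    idx, w = select_experts(torch.randn(4, 256), torch.zeros(256))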
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
@@ -47,7 +47,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
   kwargs:
@@ -55,7 +55,7 @@
     prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
   kwargs:
@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate
   kwargs:
@@ -72,7 +72,7 @@
     prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function
   kwargs:
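Note: the two match patterns split the decoder between devices: `(0|[1-9]|[12][0-9])` covers layers 0-29 and `([3456][0-9])` covers layers 30-60 of the 61-layer model. A quick check of that split:

    import re

    first_half = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.mlp$")
    second_half = re.compile(r"^model\.layers\.([3456][0-9])\.mlp$")
    lo = [i for i in range(61) if first_half.match(f"model.layers.{i}.mlp")]
    hi = [i for i in range(61) if second_half.match(f"model.layers.{i}.mlp")]
    assert lo == list(range(30)) and hi == list(range(30, 61))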
ktransformers/server/config/config.py
@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
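Note: the default generation budget rises from 500 to 2000 tokens; an explicit `max_new_tokens` in the model section still takes precedence. The lookup semantics in miniature:

    model_cfg = {}                                                       # nothing set under model: in config.yaml
    assert model_cfg.get("max_new_tokens", 2000) == 2000                 # new default applies
    assert {"max_new_tokens": 512}.get("max_new_tokens", 2000) == 512    # explicit value still wins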
ktransformers/util/modeling_rope_utils.py
@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
 
     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
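Note: the new condition rescales the NTK base only once the running sequence actually exceeds the pretrained window; at or below it, the exponent's argument is exactly 1 and the base is untouched. A small check of the formula above (hypothetical numbers):

    def ntk_base(base, factor, seq_len, max_pos, dim):
        seq_len = seq_len if seq_len is not None and seq_len > max_pos else max_pos
        return base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))

    assert ntk_base(10000.0, 4.0, 2048, 4096, 64) == 10000.0   # within the pretrained window: unchanged
    assert ntk_base(10000.0, 4.0, 8192, 4096, 64) > 10000.0    # beyond it: base grows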
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = config.qk_rope_head_dim
+    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     max_position_embeddings = config.max_position_embeddings
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
 
     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
 
     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)
 
-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity
@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func
 
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
 
-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
 
     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
 
     return inv_freq, attention_factor
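Note: with the DeepSeek-style fields present, the attention factor becomes a ratio of two mscale terms, and it collapses to 1.0 whenever `mscale == mscale_all_dim`. A worked sketch of the helper introduced above (hypothetical scaling values):

    import math

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    factor = 163840 / 4096                       # max_position_embeddings / original, illustrative
    mscale, mscale_all_dim = 1.0, 1.0
    attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
    assert attention_factor == 1.0               # equal mscales cancel out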
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
         device (`torch.device`):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
         rope_kwargs (`Dict`, *optional*):
             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
     Returns:
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
@@ -271,22 +296,20 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
 
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > original_max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
 
-    new_freqs = []
-    for freq in inv_freq:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
-    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
-    return inv_freq, attention_factor
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
 
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
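Note: the per-frequency loop is replaced by masked tensor arithmetic. A condensed, runnable restatement of the new math with toy inputs (same formulas as above, values are illustrative):

    import math
    import torch

    factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))

    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq_llama / factor + smooth * inv_freq_llama
    is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium, smoothed, inv_freq_llama)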
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }
 
-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
-        received_keys.add("rope_type")
+        required_keys.add("rope_type")
+
+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
 
     missing_keys = required_keys - received_keys
     if missing_keys:
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
 
-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )
 
-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
 
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     attention_factor = rope_scaling.get("attention_factor")
-    if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-        logger.warning(
-            f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-        )
+    if attention_factor is not None:
+        if not isinstance(attention_factor, float) or attention_factor < 0.0:
+            logger.warning(
+                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+            )
 
-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }
 
-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
\ No newline at end of file
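Note: with `ignore_keys` plumbed through every validator, callers can declare model-specific `rope_scaling` entries that should not trigger unknown-key warnings. A condensed restatement of the key check (illustrative, mirrors the logic above; the extra key name is made up):

    def check_keys(received, required, optional=None, ignore_keys=None):
        if ignore_keys is not None:
            received = received - ignore_keys            # model-specific keys, skipped on purpose
        missing = required - received
        unused = received - required - (optional or set())
        return missing, unused

    missing, unused = check_keys({"rope_type", "factor", "mscale_extra"},
                                 {"rope_type", "factor"},
                                 optional={"beta_fast", "beta_slow"},
                                 ignore_keys={"mscale_extra"})
    assert not missing and not unused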