ox696c / ktransformers · Commits

Commit f873558a — authored Feb 01, 2025 by Azure

    update rope calculation; update modeling.py; update gate for moe

Parent: 5a50b346

Showing 11 changed files with 406 additions and 416 deletions
ktransformers/configs/config.yaml                                        +1    -1
ktransformers/local_chat.py                                              +2    -2
ktransformers/models/configuration_deepseek_v3.py                        +51   -47
ktransformers/models/custom_cache.py                                     +4    -0
ktransformers/models/modeling_deepseek_v3.py                             +230  -290
ktransformers/operators/attention.py                                     +3    -2
ktransformers/operators/experts.py                                       +5    -4
ktransformers/operators/gate.py                                          +3    -4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml    +4    -4
ktransformers/server/config/config.py                                    +1    -1
ktransformers/util/modeling_rope_utils.py                                +102  -61
ktransformers/configs/config.yaml

@@ -54,4 +54,4 @@ long_context:
   token_step:
 local_chat:
-  prompt_file: "./ktransformers/p.txt"
+  prompt_file: ""
\ No newline at end of file
ktransformers/local_chat.py

@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM

@@ -78,7 +78,7 @@ def local_chat():
             else:
                 content += line + "\n"
         if content == "":
-            if config.prompt_file == None or config.prompt_file == "":
+            if not config.prompt_file:
                 content = "hi"
             else:
                 content = open(config.prompt_file, "r").read()
ktransformers/models/configuration_deepseekv3.py → ktransformers/models/configuration_deepseek_v3.py

@@ -14,19 +14,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-DeepSeekV3 model configuration
-"""
+"""DeepSeekV3 model configuration"""
 from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation

 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the DeepSeek-V3.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 129280):
             Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
             Dimension of the MoE representations.
         num_hidden_layers (`int`, *optional*, defaults to 61):
             Number of hidden layers in the Transformer decoder.
-        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
-            Number of nextn predict layers in the DeepSeekV3 Model.
         num_attention_heads (`int`, *optional*, defaults to 128):
             Number of attention heads for each attention layer in the Transformer decoder.
         num_key_value_heads (`int`, *optional*, defaults to 128):

@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `num_attention_heads`.
         n_shared_experts (`int`, *optional*, defaults to 1):
-            Number of shared experts, None means dense model.
+            Number of shared experts.
         n_routed_experts (`int`, *optional*, defaults to 256):
-            Number of routed experts, None means dense model.
-        ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
+            Number of routed experts.
         routed_scaling_factor (`float`, *optional*, defaults to 2.5):
             Scaling factor or routed experts.
-        kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
-        q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
-        qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
-        v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
-            Topk method used in routed gate.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
         n_group (`int`, *optional*, defaults to 8):
             Number of groups for routed experts.
         topk_group (`int`, *optional*, defaults to 4):
             Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
         num_experts_per_tok (`int`, *optional*, defaults to 8):
             Number of selected experts, None means dense model.
-        moe_layer_freq (`int`, *optional*, defaults to 1):
-            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
         first_k_dense_replace (`int`, *optional*, defaults to 3):
             Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                             \--k dense layers--/
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the weights of the routed experts.
-        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
-            Method of computing expert weights.
         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
             Auxiliary loss weight coefficient.
-        seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
             Whether to compute the auxiliary loss for each individual sample.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
         Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
     ```python
     >>> from transformers import DeepseekV3Model, DeepseekV3Config
     >>> # Initializing a Deepseek-V3 style configuration
     >>> configuration = DeepseekV3Config()
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""

     model_type = "deepseek_v3"
     keys_to_ignore_at_inference = ["past_key_values"]

+    # Default tensor parallel plan for base model `DeepseekV3Model`
+    base_model_tp_plan = {
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
+    }
+
     def __init__(
         self,
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
         moe_intermediate_size=2048,
         num_hidden_layers=61,
-        num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
         n_shared_experts=1,
         n_routed_experts=256,
-        ep_size=1,
         routed_scaling_factor=2.5,
         kv_lora_rank=512,
         q_lora_rank=1536,
         qk_rope_head_dim=64,
         v_head_dim=128,
         qk_nope_head_dim=128,
-        topk_method='noaux_tc',
         n_group=8,
         topk_group=4,
         num_experts_per_tok=8,
-        moe_layer_freq=1,
         first_k_dense_replace=3,
         norm_topk_prob=True,
-        scoring_func='sigmoid',
         aux_loss_alpha=0.001,
-        seq_aux=True,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.moe_intermediate_size = moe_intermediate_size
         self.num_hidden_layers = num_hidden_layers
-        self.num_nextn_predict_layers = num_nextn_predict_layers
         self.num_attention_heads = num_attention_heads
         self.n_shared_experts = n_shared_experts
         self.n_routed_experts = n_routed_experts
-        self.ep_size = ep_size
         self.routed_scaling_factor = routed_scaling_factor
         self.kv_lora_rank = kv_lora_rank
         self.q_lora_rank = q_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.qk_nope_head_dim = qk_nope_head_dim
-        self.topk_method = topk_method
-        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.head_dim = qk_rope_head_dim
         self.n_group = n_group
         self.topk_group = topk_group
         self.num_experts_per_tok = num_experts_per_tok
-        self.moe_layer_freq = moe_layer_freq
         self.first_k_dense_replace = first_k_dense_replace
         self.norm_topk_prob = norm_topk_prob
-        self.scoring_func = scoring_func
         self.aux_loss_alpha = aux_loss_alpha
-        self.seq_aux = seq_aux
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)

         super().__init__(
             pad_token_id=pad_token_id,
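Editor's note: the added block follows the transformers convention of accepting a legacy `type` key in `rope_scaling` and mirroring it into `rope_type` before `rope_config_validation` runs. A minimal illustration of that normalization step (the dictionary values below are made up for the example, not taken from the commit):

```python
rope_scaling = {"type": "yarn", "factor": 40, "mscale_all_dim": 1.0}

# BC shim as in the new __init__: legacy "type" is mirrored into "rope_type"
if rope_scaling is not None and "type" in rope_scaling:
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling["rope_type"])  # "yarn"
```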
ktransformers/models/custom_cache.py

@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
+
+    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
+        """Returns the maximum shape of the cache."""
+        return self.max_cache_len
\ No newline at end of file
ktransformers/models/modeling_deepseekv3.py → ktransformers/models/modeling_deepseek_v3.py

 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
+# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_deepseekv3.py file directly. One of our CI enforces this.
+# modular_deepseek_v3.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 from typing import Callable, List, Optional, Tuple, Union

-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn

@@ -30,7 +28,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
-from .configuration_deepseekv3 import DeepseekV3Config
+from .configuration_deepseek_v3 import DeepseekV3Config

 logger = logging.get_logger(__name__)
@@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):

 class DeepseekV3MLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
         super().__init__()
         self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.moe_intermediate_size  # TODO rm hard coding
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # config.mlp_bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  # config.mlp_bias)
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
@@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
         return down_proj


-class MoEGate(nn.Module):
+class DeepseekV3TopkRouter(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
         self.n_routed_experts = config.n_routed_experts
         self.routed_scaling_factor = config.routed_scaling_factor
-        self.scoring_func = config.scoring_func
-        self.seq_aux = config.seq_aux
-        self.topk_method = config.topk_method
         self.n_group = config.n_group
         self.topk_group = config.topk_group

-        # topk selection algorithm
-        self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.hidden_size
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        if self.topk_method == "noaux_tc":
-            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+        self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))

     def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
-        if self.scoring_func == "sigmoid":
-            scores = logits.sigmoid()
-        else:
-            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
-
-        ### select top-k experts
-        if self.topk_method == "noaux_tc":
-            # assert not self.training
-            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
-            group_scores = (
-                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
-            )  # [n, n_group]
-            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (
-                group_mask.unsqueeze(-1)
-                .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
-                .reshape(bsz * seq_len, -1)
-            )  # [n, e]
-            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
-            topk_weight = scores.gather(1, topk_idx)
-        else:
-            raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
-
-        ### norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
-
-        return topk_idx, topk_weight
+        batch_size, seq_length = hidden_states.shape[:-1]
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )  # [n, n_group]
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
+        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )  # [n, e]
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+        _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
+        topk_weights = scores.gather(1, topk_indices)
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor  # must multiply the scaling factor
+        return topk_indices, topk_weights, router_logits


 class DeepseekV3MoE(nn.Module):
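Editor's note: the replacement router above keeps DeepSeek-V3's auxiliary-loss-free, group-limited top-k selection while dropping the configurable `topk_method`/`scoring_func` branches. A minimal, self-contained sketch of that selection step, written against the reconstructed forward above (the function name, shapes and example values are the editor's assumptions, not code from the commit):

```python
import torch

def group_limited_topk(scores, bias, n_group=8, topk_group=4, top_k=8, routed_scaling_factor=2.5):
    """Sketch of DeepSeek-V3 'noaux_tc' routing: bias-corrected sigmoid scores,
    group pruning, then per-token top-k over the surviving experts."""
    n_tokens, n_experts = scores.shape
    corrected = scores + bias                                   # the bias only steers selection ...
    group_scores = corrected.view(n_tokens, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(n_tokens, n_group, n_experts // n_group).reshape(n_tokens, -1)
    masked = corrected.masked_fill(~score_mask.bool(), 0.0)
    topk_idx = masked.topk(top_k, dim=-1, sorted=False)[1]
    weights = scores.gather(1, topk_idx)                        # ... the uncorrected scores become the weights
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_idx, weights * routed_scaling_factor

scores = torch.rand(4, 256).sigmoid()                           # 4 tokens, 256 experts
bias = torch.zeros(256)
idx, w = group_limited_topk(scores, bias)
print(idx.shape, w.shape)                                       # torch.Size([4, 8]) torch.Size([4, 8])
```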
@@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):

     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.num_experts_per_tok = config.num_experts_per_tok
-        if hasattr(config, "ep_size") and config.ep_size > 1:
-            assert config.ep_size == dist.get_world_size()
-            self.ep_size = config.ep_size
-            self.experts_per_rank = config.n_routed_experts // config.ep_size
-            self.ep_rank = dist.get_rank()
-            self.experts = nn.ModuleList(
-                [
-                    (
-                        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
-                        if i >= self.ep_rank * self.experts_per_rank
-                        and i < (self.ep_rank + 1) * self.experts_per_rank
-                        else None
-                    )
-                    for i in range(config.n_routed_experts)
-                ]
-            )
-        else:
-            self.ep_size = 1
-            self.experts_per_rank = config.n_routed_experts
-            self.ep_rank = 0
-            self.experts = nn.ModuleList(
-                [DeepseekV3MLP(config) for i in range(config.n_routed_experts)]
-            )
-        self.gate = MoEGate(config)
-        if config.n_shared_experts is not None:
-            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV3MLP(config=config)
+        self.experts = nn.ModuleList(
+            [
+                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = DeepseekV3TopkRouter(config)
+        self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)

     def forward(self, hidden_states):
-        identity = hidden_states
+        residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if not self.training:
-            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
-        if self.config.n_shared_experts is not None:
-            y = y + self.shared_experts(identity)
-        return y
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states, router_logits

-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
-            )
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
-                dim=0
-            )
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weight.dtype)
-            .mul_(topk_weight.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
-        )
-        return final_out
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        return final_hidden_states.type(hidden_states.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+    (docstring identical to the definition previously located above DeepseekV3Attention)
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
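Editor's note: the expert-parallel `moe_infer` path above is removed in favour of a dense dispatch that visits every expert and scatters weighted outputs back with `index_add_`. A small sketch of that dispatch pattern checked against a brute-force per-token reference (the toy sizes and the `nn.Linear` stand-ins for experts are the editor's assumptions, not code from the commit):

```python
import torch

n_tokens, hidden, n_experts, top_k = 6, 4, 8, 2
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(n_experts)]

hidden_states = torch.randn(n_tokens, hidden)
topk_indices = torch.randint(0, n_experts, (n_tokens, top_k))
topk_weights = torch.rand(n_tokens, top_k)

# one_hot -> [n_tokens, top_k, n_experts]; permute -> [n_experts, n_tokens, top_k]
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=n_experts).permute(2, 0, 1)

out = torch.zeros_like(hidden_states)
for e in range(n_experts):
    token_idx, slot_idx = torch.where(expert_mask[e])   # tokens routed to expert e, and their top-k slot
    if token_idx.numel() == 0:
        continue
    out.index_add_(0, token_idx,
                   experts[e](hidden_states[token_idx]) * topk_weights[token_idx, slot_idx].unsqueeze(-1))

# reference: accumulate each token's weighted expert outputs directly
ref = torch.stack([
    sum(topk_weights[t, s] * experts[topk_indices[t, s]](hidden_states[t]) for s in range(top_k))
    for t in range(n_tokens)
])
print(torch.allclose(out, ref, atol=1e-5))
```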
@@ -359,150 +292,94 @@ def eager_attention_forward(
     return attn_output, attn_weights


-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-    (docstring identical to the definition that now lives above DeepseekV3Attention)
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0


 class DeepseekV3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
+    def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.q_lora_rank = config.q_lora_rank
         self.qk_rope_head_dim = config.qk_rope_head_dim
         self.kv_lora_rank = config.kv_lora_rank
         self.v_head_dim = config.v_head_dim
         self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
         self.is_causal = True

-        if self.q_lora_rank is None:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
-            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
-            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
+        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+        self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

         self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
+            config.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
             bias=config.attention_bias,
         )
-        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
+        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
         self.kv_b_proj = nn.Linear(
-            config.kv_lora_rank,
+            self.kv_lora_rank,
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
-            self.hidden_size,
+            config.hidden_size,
             bias=config.attention_bias,
         )

-        self.rotary_emb = DeepseekV3RotaryEmbedding(
-            config=self.config,
-        )
+        self.scaling = self.q_head_dim ** (-0.5)
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scaling = self.scaling * mscale * mscale
+        # TODO apply in DeepSeekV3Model to share accrose layers
+        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)

     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,  # : Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, self.num_heads, -1)

-        if self.q_lora_rank is None:
-            q = self.q_proj(hidden_states)
-        else:
-            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
+        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv = (
-            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            .transpose(1, 2)
-        )
-        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
-        kv_seq_len = value_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
-        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
-
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
-
-        if self.q_head_dim != self.v_head_dim:
-            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
+        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+        k_rot = k_rot.expand(-1, self.num_heads, -1, -1)
+
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

         if past_key_value is not None:
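Editor's note: in line with the commit message's "update rope calculation", the rewritten attention folds the YaRN `mscale` correction into the softmax scaling factor instead of leaving the base `1/sqrt(q_head_dim)` untouched. A quick sketch of how that factor is computed (the YaRN settings below are illustrative values chosen by the editor, not read from a real checkpoint config):

```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Illustrative DeepSeek-V3-style numbers (assumption for the example):
q_head_dim = 128 + 64                  # qk_nope_head_dim + qk_rope_head_dim
factor, mscale_all_dim = 40, 1.0       # rope_scaling["factor"], rope_scaling["mscale_all_dim"]

base_scaling = q_head_dim ** -0.5
mscale = yarn_get_mscale(factor, mscale_all_dim)
print(base_scaling, mscale, base_scaling * mscale * mscale)   # scaling used by attention_interface
```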
@@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                 )
             else:
-                pass
-        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+                raise NotImplementedError(
+                    f"Attention implementation {self.config._attn_implementation} is not supported. "
+                    "Please use 'eager' or 'sdpa'."
+                )
+        # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,

@@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
             scaling=self.scaling,
             **kwargs,
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
-        attn_output = self.o_proj(attn_output)
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

-        self.mlp = (
-            DeepseekV3MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV3MLP(config)
-        )
+        if layer_idx >= config.first_k_dense_replace:
+            self.mlp = DeepseekV3MoE(config)
+        else:
+            self.mlp = DeepseekV3MLP(config)

         self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC

@@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(hidden_states, tuple):
+            hidden_states, router_logits = hidden_states
+        else:
+            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
+        if output_router_logits:
+            outputs += (router_logits,)

         return outputs


-DEEPSEEKV3_START_DOCSTRING = r"""
+DEEPSEEK_V3_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""

 @add_start_docstrings(
     "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3PreTrainedModel(PreTrainedModel):
     config_class = DeepseekV3Config

@@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
             module.weight.data[module.padding_idx].zero_()


-DEEPSEEKV3_INPUTS_DOCSTRING = r"""
+DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide

@@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""

 @add_start_docstrings(
     "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3Model(DeepseekV3PreTrainedModel):
     """

@@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
     config: DeepseekV3Config
     """

-    def __init__(self, config: DeepseekV3Config):
+    def __init__(self, config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

@@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+        self._register_load_state_dict_pre_hook(self.load_hook)

         # Initialize weights and apply final processing
         self.post_init()

@@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):

         return causal_mask

+    def load_hook(self, state_dict, prefix, *args):
+        """
+        Weights have to be permuted for correct rope formulation. We can't do this in the weights
+        as every other framework already uses the `Llama` original function (which is copyrighted btw).
+        And I am not even sure it's better.... anyways end of my rant
+        """
+
+        def permute_for_rope(input_tensor):
+            """
+            When you go from the complex ROPE formulation to sin and cos one, you need
+            to permute the query and key weights (to avoid doing it on the fly)
+            """
+            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
+            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
+            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
+            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
+            return input_tensor
+
+        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
+            weight = state_dict[key]
+            weight = weight.view(num_heads, head_dim, -1)
+            weight_rot = weight[:, -rope_dim:]
+            weight_rot = permute_for_rope(weight_rot)
+            weight[:, -rope_dim:] = weight_rot
+            weight = weight.view(-1, weight.shape[-1])
+            state_dict[key] = weight
+
+        for k in state_dict:
+            if "q_b_proj." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=self.config.num_attention_heads,
+                    head_dim=self.config.q_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+            if "kv_a_proj_with_mqa." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=1,
+                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+
     # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
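Editor's note: the new `load_hook` rewrites the rope rows of `q_b_proj` and `kv_a_proj_with_mqa` checkpoint weights from the interleaved complex-pair layout into the half-split layout expected by the Llama-style `apply_rotary_pos_emb` defined earlier in the file. A toy demonstration of the index shuffle that `permute_for_rope` performs (the tensor below is purely illustrative):

```python
import torch

def permute_for_rope(input_tensor):
    # reorder the rope rows of each head from interleaved pairs (0,1),(2,3),... to
    # the half-split order (all even rows first, then all odd rows)
    n_heads, dim1, dim2 = input_tensor.shape
    input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
    input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
    return input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)

w = torch.arange(8).float().view(1, 8, 1)      # one head, rope_dim=8, one input column
print(w.flatten().tolist())                     # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(permute_for_rope(w).flatten().tolist())   # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
```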
@@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         return self.model

     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,

@@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         ```python
         >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

-        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
-        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
+        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")

@@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
     def __init__(self, config):

@@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -1213,4 +1145,12 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
-\ No newline at end of file
+
+
+__all__ = [
+    "DeepseekV3PreTrainedModel",
+    "DeepseekV3Model",
+    "DeepseekV3ForCausalLM",
+    "DeepseekV3ForSequenceClassification",
+]
\ No newline at end of file
ktransformers/operators/attention.py
View file @ f873558a
...
@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
+from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
...
@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
...
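A short note on the import split above: both modeling files define a function named apply_rotary_pos_emb, so the V3 variant is imported under an alias to keep the V2 symbol usable in the same module. The two lines below simply restate the new imports for clarity:

from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3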
ktransformers/operators/experts.py
View file @ f873558a
...
@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
 from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # only for generate phase
         if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
...
@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y
+            return y, router_logits

         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
...
@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         )
         if self.config.n_shared_experts is not None:
             y += y_
-        return y
+        return y, router_logits

     @torch.no_grad()
     def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
...
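Because the gate now returns router_logits as a third value, every caller has to unpack three results instead of two, and KDeepseekV3MoE.forward passes the logits through. A hedged, toy-sized sketch of that calling convention (the class below is a stand-in for illustration, not the repository's DeepseekV3TopkRouter):

import torch

class ToyTopkRouter(torch.nn.Module):
    """Stand-in router mimicking the (topk_idx, topk_weight, router_logits) interface."""

    def __init__(self, hidden_size=16, n_experts=8, top_k=2):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n_experts, hidden_size))
        self.top_k = top_k

    def forward(self, hidden_states):
        router_logits = hidden_states @ self.weight.T              # (tokens, n_experts)
        topk_weight, topk_idx = torch.topk(router_logits.softmax(dim=-1), self.top_k, dim=-1)
        return topk_idx, topk_weight, router_logits

x = torch.randn(4, 16)
topk_idx, topk_weight, router_logits = ToyTopkRouter()(x)  # three values, as in the diff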
ktransformers/operators/gate.py
View file @ f873558a
...
@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
 import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseekv3 import MoEGate
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
...
@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         else:
             raise ValueError("Invalid weight type")
         self.orig_module.weight = self.orig_module.weight.to(device)
-        if self.topk_method == "noaux_tc":
-            self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

     def unload(self):
         if self.weight is not None:
             self.weight = None
-        if self.topk_method == "noaux_tc":
+        if self.e_score_correction_bias is not None:
             self.e_score_correction_bias = None
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
View file @ f873558a
...
@@ -47,7 +47,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:
...
@@ -55,7 +55,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:
...
@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate
     kwargs:
...
@@ -72,7 +72,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function
     kwargs:
...
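The two name patterns above split the MoE layers between devices; a quick standalone check (not part of the repository) confirms which layer indices each regular expression covers:

import re

first = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.mlp$")
second = re.compile(r"^model\.layers\.([3456][0-9])\.mlp$")

print([i for i in range(100) if first.match(f"model.layers.{i}.mlp")])   # layers 0-29
print([i for i in range(100) if second.match(f"model.layers.{i}.mlp")])  # layers 30-69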
ktransformers/server/config/config.py
View file @ f873558a
...
@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
...
ktransformers/util/modeling_rope_utils.py
View file @ f873558a
...
@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)

     attention_factor = 1.0  # Unused in this type of RoPE
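For reference, the quantity this helper produces is the standard RoPE inverse-frequency table; a minimal sketch of the same computation outside the helper (the concrete numbers are illustrative):

import torch

base, head_dim, partial_rotary_factor = 10000.0, 64, 1.0
dim = int(head_dim * partial_rotary_factor)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
print(inv_freq.shape)  # torch.Size([32]); inv_freq[i] == base ** (-2*i / dim)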
...
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]

     attention_factor = 1.0  # Unused in this type of RoPE

     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
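With the new condition, a sequence shorter than max_position_embeddings no longer rescales the base at all; the rescaling only starts once the running length exceeds the pretrained window. A hedged numeric illustration with made-up values:

base, dim, factor, max_pos = 10000.0, 64, 4.0, 4096

def dynamic_ntk_base(seq_len):
    # mirrors the updated seq_len handling above
    seq_len = seq_len if seq_len is not None and seq_len > max_pos else max_pos
    return base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))

print(dynamic_ntk_base(1024))  # 10000.0 -- short prompts keep the pretrained base
print(dynamic_ntk_base(8192))  # ~5.3x larger -- the base grows once seq_len exceeds max_pos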
...
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = config.qk_rope_head_dim
-    max_position_embeddings = config.max_position_embeddings
+    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0

     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)

     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
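When the DeepSeek-style keys are present, the attention factor now comes from the ratio of two get_mscale calls rather than from 0.1 * log(factor) + 1. A hedged worked example (the numbers sketch a DeepSeek-V3-like configuration and are illustrative, not read from any specific checkpoint):

import math

def get_mscale(scale, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

rope_scaling = {"factor": 40.0, "mscale": 1.0, "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096}
max_position_embeddings = 163840

# factor is re-derived from the two context lengths, as in the code above
factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]

if rope_scaling["mscale"] and rope_scaling["mscale_all_dim"]:
    attention_factor = float(get_mscale(factor, rope_scaling["mscale"])
                             / get_mscale(factor, rope_scaling["mscale_all_dim"]))
else:
    attention_factor = get_mscale(factor)

print(factor, attention_factor)  # 40.0 1.0 -- equal mscale values cancel out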
...
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)

-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity
...
@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func

+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)

-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
     return inv_freq, attention_factor
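Putting the renamed pieces together, here is a compact, self-contained sketch of the blend computed above, with illustrative low/high cut-offs standing in for the find_correction_range output: dimensions below low keep the pretrained frequencies (extrapolation), dimensions above high are divided by factor (interpolation), and the ramp mixes the two in between.

import torch

def linear_ramp_factor(min_, max_, dim):
    if min_ == max_:
        max_ += 0.001  # prevent singularity
    linear_func = (torch.arange(dim, dtype=torch.float32) - min_) / (max_ - min_)
    return torch.clamp(linear_func, 0, 1)

base, dim, factor, low, high = 10000.0, 64, 4.0, 10.0, 25.0
pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)

inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = (
    inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
    + inv_freq_extrapolation * inv_freq_extrapolation_factor
)
print(inv_freq[:3])   # ~pretrained frequencies (extrapolated dimensions)
print(inv_freq[-3:])  # ~pretrained / factor (interpolated dimensions)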
...
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
         device (`torch.device`):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
         rope_kwargs (`Dict`, *optional*):
             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
     Returns:
...
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
...
@@ -271,22 +296,20 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings

     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))

     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > original_max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
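A small hedged illustration of the new selection rule: the long_factor table is only chosen when the caller passes a sequence length that exceeds original_max_position_embeddings; everything else, including the seq_len=None initialization path, falls back to short_factor.

original_max_position_embeddings = 4096  # illustrative value

def pick_factors(seq_len, short_factor="short", long_factor="long"):
    # mirrors: if seq_len and seq_len > original_max_position_embeddings
    if seq_len and seq_len > original_max_position_embeddings:
        return long_factor
    return short_factor

print(pick_factors(None))   # 'short' -- init time
print(pick_factors(1024))   # 'short'
print(pick_factors(32768))  # 'long'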
...
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor

-    new_freqs = []
-    for freq in inv_freq:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
-    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
-    return inv_freq, attention_factor
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor


 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
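As a sanity check (not part of the commit), the vectorized torch.where path above is meant to reproduce the old per-frequency Python loop; the parameter values below are illustrative rather than taken from any particular Llama 3 checkpoint.

import math
import torch

factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

# old, loop-based form
new_freqs = []
for freq in inv_freq:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
        new_freqs.append(freq)
    elif wavelen > low_freq_wavelen:
        new_freqs.append(freq / factor)
    else:
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
loop_result = torch.stack(new_freqs)

# new, vectorized form
wavelen = 2 * math.pi / inv_freq
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

print(torch.allclose(loop_result, inv_freq_llama))  # True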
...
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }


-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
-        received_keys.add("rope_type")
+        required_keys.add("rope_type")

+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
+
     missing_keys = required_keys - received_keys
     if missing_keys:
...
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
...
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )


-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
...
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )


-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
...
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
...
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }


-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
...
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
\ No newline at end of file
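Finally, a hedged usage sketch for the new ignore_keys plumbing: a config can now carry model-specific entries in rope_scaling without triggering the "Unrecognized keys" warning. The config class below is a throwaway stand-in for illustration, not one of the repository's classes.

from ktransformers.util.modeling_rope_utils import rope_config_validation

class DummyConfig:
    rope_theta = 10000.0
    hidden_size = 4096
    num_attention_heads = 32
    max_position_embeddings = 4096
    rope_scaling = {"rope_type": "yarn", "factor": 4.0, "my_private_key": 123}

# without ignore_keys this call would log an "Unrecognized keys" warning for my_private_key
rope_config_validation(DummyConfig(), ignore_keys={"my_private_key"})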