OpenDAS / ktransformers · Commits · f873558a

Commit f873558a, authored Feb 01, 2025 by Azure
Parent: 5a50b346

    update rope calculation; update modeling.py; update gate for moe

Showing 11 changed files with 406 additions and 416 deletions.
Changed files:

  ktransformers/configs/config.yaml                                        +1    -1
  ktransformers/local_chat.py                                              +2    -2
  ktransformers/models/configuration_deepseek_v3.py                        +51   -47
  ktransformers/models/custom_cache.py                                     +4    -0
  ktransformers/models/modeling_deepseek_v3.py                             +230  -290
  ktransformers/operators/attention.py                                     +3    -2
  ktransformers/operators/experts.py                                       +5    -4
  ktransformers/operators/gate.py                                          +3    -4
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml    +4    -4
  ktransformers/server/config/config.py                                    +1    -1
  ktransformers/util/modeling_rope_utils.py                                +102  -61
ktransformers/configs/config.yaml

@@ -54,4 +54,4 @@ long_context:
   token_step:
 local_chat:
-  prompt_file: "./ktransformers/p.txt"
\ No newline at end of file
+  prompt_file: ""
\ No newline at end of file
ktransformers/local_chat.py

@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM

@@ -78,7 +78,7 @@ def local_chat():
             else:
                 content += line + "\n"
     if content == "":
-        if config.prompt_file == None or config.prompt_file == "":
+        if not config.prompt_file:
            content = "hi"
        else:
            content = open(config.prompt_file, "r").read()
ktransformers/models/configuration_deepseekv3.py → ktransformers/models/configuration_deepseek_v3.py

@@ -14,19 +14,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-DeepSeekV3 model configuration
-"""
+"""DeepSeekV3 model configuration"""
 from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation

 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the DeepSeek-V3.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 129280):
             Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
             Dimension of the MoE representations.
         num_hidden_layers (`int`, *optional*, defaults to 61):
             Number of hidden layers in the Transformer decoder.
-        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
-            Number of nextn predict layers in the DeepSeekV3 Model.
         num_attention_heads (`int`, *optional*, defaults to 128):
             Number of attention heads for each attention layer in the Transformer decoder.
         num_key_value_heads (`int`, *optional*, defaults to 128):

@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `num_attention_heads`.
         n_shared_experts (`int`, *optional*, defaults to 1):
-            Number of shared experts, None means dense model.
+            Number of shared experts.
         n_routed_experts (`int`, *optional*, defaults to 256):
-            Number of routed experts, None means dense model.
-        ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
+            Number of routed experts.
         routed_scaling_factor (`float`, *optional*, defaults to 2.5):
             Scaling factor or routed experts.
-        kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
-        q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
-        qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
-        v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
-            Topk method used in routed gate.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
         n_group (`int`, *optional*, defaults to 8):
             Number of groups for routed experts.
         topk_group (`int`, *optional*, defaults to 4):
             Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
         num_experts_per_tok (`int`, *optional*, defaults to 8):
             Number of selected experts, None means dense model.
         moe_layer_freq (`int`, *optional*, defaults to 1):
             The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
         first_k_dense_replace (`int`, *optional*, defaults to 3):
             Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                             \--k dense layers--/
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the weights of the routed experts.
-        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
-            Method of computing expert weights.
         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
             Auxiliary loss weight coefficient.
-            Whether to compute the auxiliary loss for each individual sample.
-        seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
     ```python
     >>> from transformers import DeepseekV3Model, DeepseekV3Config
     >>> # Initializing a Deepseek-V3 style configuration
     >>> configuration = DeepseekV3Config()
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""

     model_type = "deepseek_v3"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `DeepseekV3Model`
+    base_model_tp_plan = {
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
+    }

     def __init__(
         self,
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
         moe_intermediate_size=2048,
         num_hidden_layers=61,
-        num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
-        n_shared_experts=1,
-        n_routed_experts=256,
-        ep_size=1,
-        routed_scaling_factor=2.5,
-        kv_lora_rank=512,
-        q_lora_rank=1536,
-        qk_rope_head_dim=64,
-        v_head_dim=128,
-        qk_nope_head_dim=128,
-        topk_method='noaux_tc',
-        n_group=8,
-        topk_group=4,
-        num_experts_per_tok=8,
-        moe_layer_freq=1,
-        first_k_dense_replace=3,
-        norm_topk_prob=True,
-        scoring_func='sigmoid',
-        aux_loss_alpha=0.001,
-        seq_aux=True,
+        n_shared_experts=1,
+        n_routed_experts=256,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        aux_loss_alpha=0.001,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
self
.
intermediate_size
=
intermediate_size
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_nextn_predict_layers
=
num_nextn_predict_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
n_shared_experts
=
n_shared_experts
self
.
n_routed_experts
=
n_routed_experts
self
.
ep_size
=
ep_size
self
.
routed_scaling_factor
=
routed_scaling_factor
self
.
kv_lora_rank
=
kv_lora_rank
self
.
q_lora_rank
=
q_lora_rank
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
topk_method
=
topk_method
self
.
q_head_dim
=
qk_nope_head_dim
+
qk_rope_head_dim
self
.
head_dim
=
qk_rope_head_dim
self
.
n_group
=
n_group
self
.
topk_group
=
topk_group
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
moe_layer_freq
=
moe_layer_freq
self
.
first_k_dense_replace
=
first_k_dense_replace
self
.
norm_topk_prob
=
norm_topk_prob
self
.
scoring_func
=
scoring_func
self
.
aux_loss_alpha
=
aux_loss_alpha
self
.
seq_aux
=
seq_aux
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
...
...
@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)

         super().__init__(
             pad_token_id=pad_token_id,
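The back-compat shim added above just mirrors a legacy "type" key into "rope_type" before handing the dict to transformers' rope_config_validation. A minimal standalone sketch (illustrative values, not taken from the commit) of that shim:

# Legacy checkpoints may still ship a rope_scaling dict keyed by "type";
# the new __init__ copies it into "rope_type" before validation runs.
rope_scaling = {
    "type": "yarn",                           # old-style key
    "factor": 40,
    "original_max_position_embeddings": 4096,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
}

if rope_scaling is not None and "type" in rope_scaling:
    rope_scaling["rope_type"] = rope_scaling["type"]

assert rope_scaling["rope_type"] == "yarn"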
ktransformers/models/custom_cache.py

@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
+
+    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
+        """Returns the maximum shape of the cache."""
+        return self.max_cache_len
\ No newline at end of file
ktransformers/models/modeling_deepseekv3.py → ktransformers/models/modeling_deepseek_v3.py

 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
+# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
 # Do NOT edit this file manually as any edits will be overwritten by the generation of
 # the file from the modular. If any change should be done, please apply the change to the
-# modular_deepseekv3.py file directly. One of our CI enforces this.
+# modular_deepseek_v3.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 from typing import Callable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
@@ -30,7 +28,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
-from .configuration_deepseekv3 import DeepseekV3Config
+from .configuration_deepseek_v3 import DeepseekV3Config

 logger = logging.get_logger(__name__)
@@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):
 class DeepseekV3MLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
         super().__init__()
         self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.moe_intermediate_size  # TODO rm hard coding
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # config.mlp_bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  # config.mlp_bias)
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
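The MLP change above makes the intermediate size an explicit constructor argument instead of always reading config.moe_intermediate_size, so the same class can serve as the large dense MLP or the smaller routed expert. A minimal standalone sketch (torch only, SiLU activation, toy sizes assumed) of that gated wiring:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGatedMLP(nn.Module):
    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU-style gating: act(gate(x)) * up(x), projected back down
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

dense_mlp = TinyGatedMLP(hidden_size=16, intermediate_size=32)   # dense / shared expert size
routed_expert = TinyGatedMLP(hidden_size=16, intermediate_size=8)  # smaller routed expert size
print(routed_expert(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])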
@@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
         return down_proj


-class MoEGate(nn.Module):
+class DeepseekV3TopkRouter(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
         self.n_routed_experts = config.n_routed_experts
         self.routed_scaling_factor = config.routed_scaling_factor
-        self.scoring_func = config.scoring_func
-        self.seq_aux = config.seq_aux
-        self.topk_method = config.topk_method
         self.n_group = config.n_group
         self.topk_group = config.topk_group
-        # topk selection algorithm
         self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.hidden_size
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        if self.topk_method == "noaux_tc":
-            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+        self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))

     def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
-        if self.scoring_func == "sigmoid":
-            scores = logits.sigmoid()
-        else:
-            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
-
-        ### select top-k experts
-        if self.topk_method == "noaux_tc":
-            # assert not self.training
-            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
-            group_scores = (
-                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
-            )  # [n, n_group]
-            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (
-                group_mask.unsqueeze(-1)
-                .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
-                .reshape(bsz * seq_len, -1)
-            )  # [n, e]
-            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
-            topk_weight = scores.gather(1, topk_idx)
-        else:
-            raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
-
-        ### norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
-
-        return topk_idx, topk_weight
+        batch_size, seq_length = hidden_states.shape[:-1]
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )  # [n, n_group]
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
+        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )  # [n, e]
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+        _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
+        topk_weights = scores.gather(1, topk_indices)
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor  # must multiply the scaling factor
+        return topk_indices, topk_weights, router_logits
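The router above keeps the same group-limited "noaux_tc" selection as the old MoEGate but drops the dead branches and also returns the raw router logits. A standalone sketch (made-up sizes, zero correction bias) of that selection path:

import torch

n_experts, n_group, topk_group, top_k = 8, 4, 2, 2
scores = torch.rand(3, n_experts).sigmoid()               # [tokens, experts]
bias = torch.zeros(n_experts)                             # e_score_correction_bias
scores_for_choice = scores + bias                         # bias only affects *which* experts win

# Rank groups by the sum of their two best experts, keep the best `topk_group` groups
group_scores = scores_for_choice.view(3, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(3, n_group, n_experts // n_group).reshape(3, -1)

# Top-k experts inside the surviving groups; weights come from the uncorrected scores
masked = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_idx = masked.topk(top_k, dim=-1, sorted=False)[1]
topk_w = scores.gather(1, topk_idx)
topk_w = topk_w / (topk_w.sum(-1, keepdim=True) + 1e-20)  # normalize, then apply routed_scaling_factor
print(topk_idx, topk_w)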
@@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
-        if hasattr(config, "ep_size") and config.ep_size > 1:
-            assert config.ep_size == dist.get_world_size()
-            self.ep_size = config.ep_size
-            self.experts_per_rank = config.n_routed_experts // config.ep_size
-            self.ep_rank = dist.get_rank()
-            self.experts = nn.ModuleList(
-                [
-                    (
-                        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
-                        if i >= self.ep_rank * self.experts_per_rank
-                        and i < (self.ep_rank + 1) * self.experts_per_rank
-                        else None
-                    )
-                    for i in range(config.n_routed_experts)
-                ]
-            )
-        else:
-            self.ep_size = 1
-            self.experts_per_rank = config.n_routed_experts
-            self.ep_rank = 0
-            self.experts = nn.ModuleList(
-                [DeepseekV3MLP(config) for i in range(config.n_routed_experts)]
-            )
-        self.gate = MoEGate(config)
-        if config.n_shared_experts is not None:
-            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV3MLP(config=config)
+        self.experts = nn.ModuleList(
+            [
+                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = DeepseekV3TopkRouter(config)
+        self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)

     def forward(self, hidden_states):
-        identity = hidden_states
+        residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if not self.training:
-            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
-        if self.config.n_shared_experts is not None:
-            y = y + self.shared_experts(identity)
-        return y
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states, router_logits
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
-            )
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(
-                self.ep_size, self.experts_per_rank
-            ).sum(dim=0)
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weight.dtype)
-            .mul_(topk_weight.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
-        )
-        return final_out
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        return final_hidden_states.type(hidden_states.dtype)
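The new moe() replaces the sort-and-all-to-all moe_infer path with a per-expert mask loop. A standalone sketch (toy sizes, plain torch, linear layers standing in for experts) of that dispatch pattern:

import torch

n_tokens, hidden, n_experts, top_k = 5, 4, 3, 2
hidden_states = torch.randn(n_tokens, hidden)
topk_indices = torch.randint(0, n_experts, (n_tokens, top_k))
topk_weights = torch.rand(n_tokens, top_k)
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(n_experts)]

out = torch.zeros_like(hidden_states)
# one-hot over experts, permuted so expert_mask[e] is a [tokens, top_k] membership map
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=n_experts).permute(2, 0, 1)
for e, expert in enumerate(experts):
    token_idx, slot_idx = torch.where(expert_mask[e])
    if token_idx.numel():
        # run expert e only on its tokens, weight by the matching routing weight, accumulate
        out.index_add_(0, token_idx, expert(hidden_states[token_idx]) * topk_weights[token_idx, slot_idx].unsqueeze(-1))
print(out.shape)  # torch.Size([5, 4])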
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -359,150 +292,94 @@ def eager_attention_forward(
     return attn_output, attn_weights


-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed


 def yarn_get_mscale(scale=1, mscale=1):
     if scale <= 1:
         return 1.0
     return 0.1 * mscale * math.log(scale) + 1.0


 class DeepseekV3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
+    def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.q_lora_rank = config.q_lora_rank
         self.qk_rope_head_dim = config.qk_rope_head_dim
         self.kv_lora_rank = config.kv_lora_rank
         self.v_head_dim = config.v_head_dim
         self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
         self.is_causal = True

         if self.q_lora_rank is None:
             self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
         else:
-            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
-            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
-            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
+            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

         self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
+            config.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
             bias=config.attention_bias,
         )
-        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
+        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
         self.kv_b_proj = nn.Linear(
-            config.kv_lora_rank,
+            self.kv_lora_rank,
             self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
             bias=False,
         )
         self.o_proj = nn.Linear(
             self.num_heads * self.v_head_dim,
-            self.hidden_size,
+            config.hidden_size,
             bias=config.attention_bias,
         )
-        self.rotary_emb = DeepseekV3RotaryEmbedding(
-            config=self.config,
-        )

         self.scaling = self.q_head_dim ** (-0.5)
         if self.config.rope_scaling is not None:
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
             scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                 self.scaling = self.scaling * mscale * mscale
+        # TODO apply in DeepSeekV3Model to share accrose layers
+        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
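Worked sketch (illustrative DeepSeek-V3-like numbers, not values from the commit) of the softmax-scale correction above: with YaRN scaling and mscale_all_dim set, the base 1/sqrt(q_head_dim) scale is multiplied by yarn_get_mscale(factor, mscale_all_dim) squared.

import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

q_head_dim = 64 + 128              # qk_rope_head_dim + qk_nope_head_dim
factor, mscale_all_dim = 40, 1.0   # example rope_scaling values
scaling = q_head_dim ** (-0.5)
scaling *= yarn_get_mscale(factor, mscale_all_dim) ** 2
print(scaling)  # ≈ 0.0722 * 1.369**2 ≈ 0.135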
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
         # : Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, self.num_heads, -1)

-        if self.q_lora_rank is None:
-            q = self.q_proj(hidden_states)
-        else:
-            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
+        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv = (
-            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            .transpose(1, 2)
-        )
+        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

-        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        kv_seq_len = value_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
+        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+        k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)

-        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+        cos, sin = position_embeddings
+        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+        k_rot = k_rot.expand(-1, self.num_heads, -1, -1)

-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)

-        if self.q_head_dim != self.v_head_dim:
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
             value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

         if past_key_value is not None:
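The rewritten forward keeps the partial-RoPE pattern: only the last qk_rope_head_dim channels of each head are rotated, the "pass" channels are concatenated back untouched. A standalone sketch (toy shapes assumed) of that split-then-rotate step:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

nope_dim, rope_dim, heads, seq = 4, 6, 2, 3
q = torch.randn(1, heads, seq, nope_dim + rope_dim)
angles = torch.rand(1, seq, rope_dim)
cos, sin = angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)  # broadcast over heads

q_pass, q_rot = torch.split(q, [nope_dim, rope_dim], dim=-1)
q_rot = q_rot * cos + rotate_half(q_rot) * sin   # rotate only the rope channels
q_out = torch.cat((q_pass, q_rot), dim=-1)
print(q_out.shape)  # torch.Size([1, 2, 3, 10])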
@@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
                 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
         else:
-            pass
-        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            raise NotImplementedError(
+                f"Attention implementation {self.config._attn_implementation} is not supported. "
+                "Please use 'eager' or 'sdpa'."
+            )
+        # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         attn_output, attn_weights = attention_interface(
             self,

@@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
             scaling=self.scaling,
             **kwargs,
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
-        attn_output = self.o_proj(attn_output)
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

-        self.mlp = (
-            DeepseekV3MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV3MLP(config)
-        )
+        if layer_idx >= config.first_k_dense_replace:
+            self.mlp = DeepseekV3MoE(config)
+        else:
+            self.mlp = DeepseekV3MLP(config)

         self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
@@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(hidden_states, tuple):
+            hidden_states, router_logits = hidden_states
+        else:
+            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
+        if output_router_logits:
+            outputs += (router_logits,)
         return outputs


-DEEPSEEKV3_START_DOCSTRING = r"""
+DEEPSEEK_V3_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
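Because the MoE block now returns (hidden_states, router_logits) while a dense MLP still returns a bare tensor, the decoder layer normalizes both cases. A tiny standalone sketch (made-up stand-in MLPs) of that handling:

import torch

def dense_mlp(x):
    return x * 2.0

def moe_mlp(x):
    return x * 2.0, torch.randn(x.shape[0], 8)   # (hidden_states, router_logits)

for mlp in (dense_mlp, moe_mlp):
    out = mlp(torch.ones(3, 4))
    if isinstance(out, tuple):
        hidden_states, router_logits = out
    else:
        hidden_states, router_logits = out, None
    print(mlp.__name__, hidden_states.shape, router_logits is None)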
@@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""
 @add_start_docstrings(
     "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3PreTrainedModel(PreTrainedModel):
     config_class = DeepseekV3Config

@@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
             module.weight.data[module.padding_idx].zero_()


-DEEPSEEKV3_INPUTS_DOCSTRING = r"""
+DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide

@@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""
 @add_start_docstrings(
     "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3Model(DeepseekV3PreTrainedModel):
     """

@@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
     config: DeepseekV3Config
     """

-    def __init__(self, config: DeepseekV3Config):
+    def __init__(self, config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

@@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+        self._register_load_state_dict_pre_hook(self.load_hook)

         # Initialize weights and apply final processing
         self.post_init()

@@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         return causal_mask

+    def load_hook(self, state_dict, prefix, *args):
+        """
+        Weights have to be permuted for correct rope formulation. We can't do this in the weights
+        as every other framework already uses the `Llama` original function (which is copyrighted btw).
+        And I am not even sure it's better.... anyways end of my rant
+        """
+
+        def permute_for_rope(input_tensor):
+            """
+            When you go from the complex ROPE formulation to sin and cos one, you need
+            to permute the query and key weights (to avoid doing it on the fly)
+            """
+            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
+            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
+            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
+            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
+            return input_tensor
+
+        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
+            weight = state_dict[key]
+            weight = weight.view(num_heads, head_dim, -1)
+            weight_rot = weight[:, -rope_dim:]
+            weight_rot = permute_for_rope(weight_rot)
+            weight[:, -rope_dim:] = weight_rot
+            weight = weight.view(-1, weight.shape[-1])
+            state_dict[key] = weight
+
+        for k in state_dict:
+            if "q_b_proj." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=self.config.num_attention_heads,
+                    head_dim=self.config.q_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+            if "kv_a_proj_with_mqa." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=1,
+                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+
 # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
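A standalone sketch of what the permutation in the load hook above does to a tiny weight: rows laid out for the interleaved (complex) RoPE convention [r0, i0, r1, i1, ...] are reordered into the half-split layout [r0, r1, ..., i0, i1, ...] that the rotate_half/sin-cos formulation expects, so it no longer has to be done on the fly.

import torch

def permute_for_rope(input_tensor):
    n_heads, dim1, dim2 = input_tensor.shape
    input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
    input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
    return input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)

# one head, rope_dim = 4, rows labelled by their interleaved position
w = torch.arange(4).reshape(1, 4, 1).float()
print(permute_for_rope(w).flatten().tolist())  # [0.0, 2.0, 1.0, 3.0] -> even rows first, then odd rows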
@@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         return self.model

     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,

@@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         ```python
         >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

-        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
-        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
+        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

         >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")

@@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
 )
 class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
     def __init__(self, config):

@@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -1213,4 +1145,12 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
\ No newline at end of file
+        )
+
+
+__all__ = [
+    "DeepseekV3PreTrainedModel",
+    "DeepseekV3Model",
+    "DeepseekV3ForCausalLM",
+    "DeepseekV3ForSequenceClassification",
+]
\ No newline at end of file
ktransformers/operators/attention.py

@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
+from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader

@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
ktransformers/operators/experts.py

@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
 from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock

@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         # only for generate phase
         if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:

@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y
+            return y, router_logits

         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)

@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         )
         if self.config.n_shared_experts is not None:
             y += y_
-        return y
+        return y, router_logits

     @torch.no_grad()
     def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
ktransformers/operators/gate.py

@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
 import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseekv3 import MoEGate
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN

@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         else:
             raise ValueError("Invalid weight type")
         self.orig_module.weight = self.orig_module.weight.to(device)
-        if self.topk_method == "noaux_tc":
-            self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

     def unload(self):
         if self.weight is not None:
             self.weight = None
-        if self.topk_method == "noaux_tc":
         if self.e_score_correction_bias is not None:
             self.e_score_correction_bias = None
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml

@@ -47,7 +47,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:

@@ -55,7 +55,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:

@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate
     kwargs:

@@ -72,7 +72,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function
     kwargs:
ktransformers/server/config/config.py

@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
ktransformers/util/modeling_rope_utils.py

@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)

     attention_factor = 1.0  # Unused in this type of RoPE
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]

     attention_factor = 1.0  # Unused in this type of RoPE

     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = config.qk_rope_head_dim
-    max_position_embeddings = config.max_position_embeddings
+    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0

     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
-    if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+    if mscale and mscale_all_dim:
+        attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+    else:
+        attention_factor = get_mscale(factor)

     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
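Worked sketch (illustrative numbers, not from the commit) of the YaRN attention factor computed above for a DeepSeek-V3-style rope_scaling block: the factor is re-derived from the context-length ratio, and the mscale pair then sets the attention scaling.

import math

rope_scaling = {"factor": 40, "original_max_position_embeddings": 4096,
                "mscale": 1.0, "mscale_all_dim": 1.0}
max_position_embeddings = 163840  # expanded context length

def get_mscale(scale, mscale=1):
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0

factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]  # 40.0
attention_factor = get_mscale(factor, rope_scaling["mscale"]) / get_mscale(factor, rope_scaling["mscale_all_dim"])
print(factor, attention_factor)  # 40.0 1.0 (equal mscale and mscale_all_dim cancel out)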
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)

-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity

@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func

     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
     # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)

-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )

     return inv_freq, attention_factor
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
         device (`torch.device`):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
         rope_kwargs (`Dict`, *optional*):
             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
     Returns:
...
...
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
...
...
@@ -271,22 +296,20 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings

     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))

     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > original_max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
...
...
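The long/short switch above now keys off the runtime sequence length rather than a precomputed expansion; a small sketch of just that selection step, with hypothetical argument names and made-up factors:

import torch

def pick_longrope_factors(seq_len, original_max_position_embeddings, long_factor, short_factor):
    # Long rescale factors only kick in once the prompt actually exceeds the pretrained context.
    use_long = bool(seq_len) and seq_len > original_max_position_embeddings
    return torch.tensor(long_factor if use_long else short_factor, dtype=torch.float32)

pick_longrope_factors(8192, 4096, long_factor=[1.5, 2.0], short_factor=[1.0, 1.0])  # selects the long factors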
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor

-    new_freqs = []
-    for freq in inv_freq:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
-    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
-    return inv_freq, attention_factor
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor


 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
...
...
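As a quick sanity check of the vectorized rewrite above, it can be run standalone with made-up Llama-3-style parameters (the numbers below are illustrative, not taken from any released config):

import math
import torch

def llama3_scale_inv_freq(inv_freq, factor, low_freq_factor, high_freq_factor, old_context_len):
    # Same three bands as the diff: high-frequency dims untouched, low-frequency dims divided
    # by `factor`, and the medium band linearly interpolated between the two.
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq_llama / factor + smooth * inv_freq_llama
    is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    return torch.where(is_medium, smoothed, inv_freq_llama)

inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
scaled = llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                               high_freq_factor=4.0, old_context_len=8192)
# Scaling never increases a frequency, so the result stays elementwise below the input.
assert scaled.shape == inv_freq.shape and torch.all(scaled <= inv_freq + 1e-8)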
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }


-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
-        received_keys.add("rope_type")
+        required_keys.add("rope_type")

+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys

     missing_keys = required_keys - received_keys
     if missing_keys:
...
...
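The new `ignore_keys` plumbing is plain set subtraction applied before the missing/unused checks; a minimal illustration, using `mrope_section` as an example of a model-specific key:

received_keys = {"rope_type", "factor", "mrope_section"}  # "mrope_section" is model-specific
required_keys = {"rope_type", "factor"}
ignore_keys = {"mrope_section"}

received_keys -= ignore_keys
missing_keys = required_keys - received_keys
unused_keys = received_keys - required_keys
print(missing_keys, unused_keys)  # set() set() -- no warning would be emitted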
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
...
...
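With the expanded optional key set above, a DeepSeek-V3-style `rope_scaling` block should validate without "unrecognized keys" warnings; roughly (field values are placeholders, not the shipped config):

rope_scaling = {
    "rope_type": "yarn",
    "factor": 40.0,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
    "original_max_position_embeddings": 4096,
}
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow",
                 "original_max_position_embeddings", "mscale", "mscale_all_dim"}
unused_keys = set(rope_scaling.keys()) - required_keys - optional_keys
assert not unused_keys  # previously "mscale" and friends would have been reported as unrecognized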
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )


-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
...
...
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )


-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
...
...
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
...
...
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }


-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
...
...
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
\ No newline at end of file
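Taken together, a config class can now forward model-specific keys to the validator; a hedged usage sketch (the toy config class and the "custom_field" key are illustrative, and the import path follows the file touched in this commit):

from transformers import PretrainedConfig
from ktransformers.util.modeling_rope_utils import rope_config_validation

class ToyRopeConfig(PretrainedConfig):
    def __init__(self, rope_scaling=None, **kwargs):
        self.rope_scaling = rope_scaling
        super().__init__(**kwargs)
        if self.rope_scaling is not None:
            # "custom_field" stands in for a model-specific key the generic validators should skip
            rope_config_validation(self, ignore_keys={"custom_field"})

# Validates cleanly: "custom_field" is ignored instead of being reported as unrecognized.
ToyRopeConfig(rope_scaling={"rope_type": "yarn", "factor": 40.0, "custom_field": True})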