OpenDAS / ktransformers / Commits / f873558a

Commit f873558a, authored Feb 01, 2025 by Azure
Parent: 5a50b346

    update rope calculation; update modeling.py; update gate for moe

Showing 11 changed files with 406 additions and 416 deletions (+406, -416).
ktransformers/configs/config.yaml                                      +1    -1
ktransformers/local_chat.py                                            +2    -2
ktransformers/models/configuration_deepseek_v3.py                      +51   -47
ktransformers/models/custom_cache.py                                   +4    -0
ktransformers/models/modeling_deepseek_v3.py                           +230  -290
ktransformers/operators/attention.py                                   +3    -2
ktransformers/operators/experts.py                                     +5    -4
ktransformers/operators/gate.py                                        +3    -4
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml  +4    -4
ktransformers/server/config/config.py                                  +1    -1
ktransformers/util/modeling_rope_utils.py                              +102  -61
ktransformers/configs/config.yaml

@@ -54,4 +54,4 @@ long_context:
   token_step:
 local_chat:
-  prompt_file: "./ktransformers/p.txt"
\ No newline at end of file
+  prompt_file: ""
\ No newline at end of file
ktransformers/local_chat.py

@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
 from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM

@@ -78,7 +78,7 @@ def local_chat():
         else:
             content += line + "\n"
     if content == "":
-        if config.prompt_file == None or config.prompt_file == "":
+        if not config.prompt_file:
             content = "hi"
         else:
             content = open(config.prompt_file, "r").read()
ktransformers/models/configuration_deepseekv3.py → ktransformers/models/configuration_deepseek_v3.py

@@ -14,19 +14,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" DeepSeekV3 model configuration"""
+"""DeepSeekV3 model configuration"""
 from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation

 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the DeepSeek-V3.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.

     Args:
         vocab_size (`int`, *optional*, defaults to 129280):
             Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
             ...

@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
             Dimension of the MoE representations.
         num_hidden_layers (`int`, *optional*, defaults to 61):
             Number of hidden layers in the Transformer decoder.
-        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
-            Number of nextn predict layers in the DeepSeekV3 Model.
         num_attention_heads (`int`, *optional*, defaults to 128):
             Number of attention heads for each attention layer in the Transformer decoder.
         num_key_value_heads (`int`, *optional*, defaults to 128):
             ...

@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `num_attention_heads`.
         n_shared_experts (`int`, *optional*, defaults to 1):
-            Number of shared experts, None means dense model.
+            Number of shared experts.
         n_routed_experts (`int`, *optional*, defaults to 256):
-            Number of routed experts, None means dense model.
-        ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
+            Number of routed experts.
         routed_scaling_factor (`float`, *optional*, defaults to 2.5):
             Scaling factor or routed experts.
-        kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
-        q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
-        qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
-        v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
-            Topk method used in routed gate.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
         n_group (`int`, *optional*, defaults to 8):
             Number of groups for routed experts.
         topk_group (`int`, *optional*, defaults to 4):
             Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
         num_experts_per_tok (`int`, *optional*, defaults to 8):
             Number of selected experts, None means dense model.
         moe_layer_freq (`int`, *optional*, defaults to 1):
             The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
         first_k_dense_replace (`int`, *optional*, defaults to 3):
             Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                             \--k dense layers--/
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the weights of the routed experts.
-        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
-            Method of computing expert weights.
         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
             Auxiliary loss weight coefficient.
-            Whether to compute the auxiliary loss for each individual sample.
         seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 4096):
             ...

@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.

     ```python
     >>> from transformers import DeepseekV3Model, DeepseekV3Config
     >>> # Initializing a Deepseek-V3 style configuration
     >>> configuration = DeepseekV3Config()
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""

     model_type = "deepseek_v3"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `DeepseekV3Model`
+    base_model_tp_plan = {
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
+    }

     def __init__(
         self,
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
         moe_intermediate_size=2048,
         num_hidden_layers=61,
-        num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
-        n_shared_experts=1,
-        n_routed_experts=256,
-        ep_size=1,
-        routed_scaling_factor=2.5,
-        kv_lora_rank=512,
-        q_lora_rank=1536,
-        qk_rope_head_dim=64,
-        v_head_dim=128,
-        qk_nope_head_dim=128,
-        topk_method='noaux_tc',
-        n_group=8,
-        topk_group=4,
-        num_experts_per_tok=8,
-        moe_layer_freq=1,
-        first_k_dense_replace=3,
-        norm_topk_prob=True,
-        scoring_func='sigmoid',
-        aux_loss_alpha=0.001,
-        seq_aux=True,
+        n_shared_experts=1,
+        n_routed_experts=256,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        aux_loss_alpha=0.001,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
         ...

@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         ...

@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.moe_intermediate_size = moe_intermediate_size
         self.num_hidden_layers = num_hidden_layers
-        self.num_nextn_predict_layers = num_nextn_predict_layers
         self.num_attention_heads = num_attention_heads
         self.n_shared_experts = n_shared_experts
         self.n_routed_experts = n_routed_experts
-        self.ep_size = ep_size
         self.routed_scaling_factor = routed_scaling_factor
         self.kv_lora_rank = kv_lora_rank
         self.q_lora_rank = q_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.qk_nope_head_dim = qk_nope_head_dim
-        self.topk_method = topk_method
-        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.head_dim = qk_rope_head_dim
         self.n_group = n_group
         self.topk_group = topk_group
         self.num_experts_per_tok = num_experts_per_tok
         self.moe_layer_freq = moe_layer_freq
         self.first_k_dense_replace = first_k_dense_replace
         self.norm_topk_prob = norm_topk_prob
         self.scoring_func = scoring_func
         self.aux_loss_alpha = aux_loss_alpha
         self.seq_aux = seq_aux
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
             ...

@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)

         super().__init__(
             pad_token_id=pad_token_id,
             ...
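For orientation, a minimal sketch of how the relocated configuration is exercised after this change: the `rope_scaling` dict is now validated at construction time, and a legacy `type` key is promoted to `rope_type`. The field values below are illustrative only, not taken from this commit.

```python
from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config

# Hypothetical YaRN settings for illustration; real checkpoints ship their own config.json.
cfg = DeepseekV3Config(
    max_position_embeddings=163840,
    rope_scaling={
        "type": "yarn",                           # legacy key, copied to "rope_type" in __init__
        "factor": 40.0,
        "original_max_position_embeddings": 4096,
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
    },
)
print(cfg.rope_scaling["rope_type"])  # "yarn"
print(cfg.head_dim)                   # 64, i.e. qk_rope_head_dim, as set in __init__
```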
ktransformers/models/custom_cache.py

@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
+
+    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
+        """Returns the maximum shape of the cache."""
+        return self.max_cache_len
\ No newline at end of file
ktransformers/models/modeling_deepseekv3.py → ktransformers/models/modeling_deepseek_v3.py

This diff is collapsed (+230 additions, -290 deletions not shown inline).
ktransformers/operators/attention.py

@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
+from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader

@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             ...
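As background for the renamed import, a rough sketch of what a rotary-embedding application looks like in half-rotation form. This is illustrative only; the real `apply_rotary_pos_emb` in `modeling_deepseek_v3.py` may interleave the rotary head dimensions differently.

```python
import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q_pe, k_pe, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension and rotate only the RoPE sub-heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q_pe * cos) + (rotate_half(q_pe) * sin)
    k_embed = (k_pe * cos) + (rotate_half(k_pe) * sin)
    return q_embed, k_embed
```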
ktransformers/operators/experts.py

@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
 from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
 from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock

@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         # only for generate phase
         if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
             self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
             if self.config.n_shared_experts is not None:
                 ...

@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y
+            return y, router_logits
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
             ...

@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         )
         if self.config.n_shared_experts is not None:
             y += y_
-        return y
+        return y, router_logits

     @torch.no_grad()
     def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         ...
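A toy illustration of the new calling convention (not the ktransformers implementation): a `DeepseekV3TopkRouter`-style gate returns the raw router scores as a third value, so `KDeepseekV3MoE.forward` unpacks three values and returns `(y, router_logits)`.

```python
import torch
import torch.nn.functional as F

class ToyTopkRouter(torch.nn.Module):
    """Toy stand-in for a top-k router that also returns its raw scores."""
    def __init__(self, hidden_size: int, n_experts: int, top_k: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n_experts, hidden_size))
        self.top_k = top_k

    def forward(self, hidden_states: torch.Tensor):
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = F.linear(flat, self.weight)             # [n_tokens, n_experts]
        scores = router_logits.sigmoid()                        # DeepSeek-V3 uses sigmoid scoring
        topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1)
        return topk_idx, topk_weight, router_logits             # three return values, as in the diff

gate = ToyTopkRouter(hidden_size=16, n_experts=8, top_k=2)
x = torch.randn(1, 4, 16)
topk_idx, topk_weight, router_logits = gate(x)                  # callers now unpack three values
```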
ktransformers/operators/gate.py

@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
 import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseekv3 import MoEGate
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN

@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         else:
             raise ValueError("Invalid weight type")
         self.orig_module.weight = self.orig_module.weight.to(device)
-        if self.topk_method == "noaux_tc":
-            self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

     def unload(self):
         if self.weight is not None:
             self.weight = None
-        if self.topk_method == "noaux_tc":
-            if self.e_score_correction_bias is not None:
-                self.e_score_correction_bias = None
+        if self.e_score_correction_bias is not None:
+            self.e_score_correction_bias = None
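For background on why `KMoEGate` now moves `e_score_correction_bias` unconditionally: DeepSeek-V3's "noaux_tc" routing always adds a learned correction bias to the sigmoid scores when picking experts, while the unbiased scores are kept as the mixing weights. A simplified, self-contained sketch of that idea (it omits the group-limited selection over `n_group`/`topk_group` and is not the ktransformers code):

```python
import torch

def noaux_tc_topk_sketch(scores: torch.Tensor, e_score_correction_bias: torch.Tensor, top_k: int = 8):
    """scores: [n_tokens, n_experts] sigmoid gate scores.
    The correction bias only influences which experts are picked, not their weights."""
    biased = scores + e_score_correction_bias                            # selection uses corrected scores
    _, topk_idx = torch.topk(biased, top_k, dim=-1)
    topk_weight = scores.gather(-1, topk_idx)                            # mixing weights come from the raw scores
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)    # norm_topk_prob=True
    return topk_idx, topk_weight

scores = torch.rand(4, 256)            # 4 tokens, 256 routed experts (DeepSeek-V3-sized, illustrative)
bias = torch.zeros(256)
topk_idx, topk_weight = noaux_tc_topk_sketch(scores, bias)
```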
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml

@@ -47,7 +47,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:
       ...

@@ -55,7 +55,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function
     kwargs:
       ...

@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate
     kwargs:
       ...

@@ -72,7 +72,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function
     kwargs:
       ...
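The `name` fields above are ordinary regular expressions over module paths; a quick way to check which layers a rule captures (standalone, illustrative):

```python
import re

first_half = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.mlp\.gate$")   # layers 0-29
second_half = re.compile(r"^model\.layers\.([3456][0-9])\.mlp\.gate$")        # layers 30-69

print(bool(first_half.match("model.layers.12.mlp.gate")))    # True
print(bool(second_half.match("model.layers.42.mlp.gate")))   # True
print(bool(first_half.match("model.layers.42.mlp.gate")))    # False
```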
ktransformers/server/config/config.py

@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
         ...
ktransformers/util/modeling_rope_utils.py

@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)

     attention_factor = 1.0  # Unused in this type of RoPE
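To see why the `head_dim` override matters for DeepSeek-V3's MLA attention, a quick worked example using the defaults from the configuration diff above (7168 hidden size, 128 heads, 64-dim rotary sub-head):

```python
hidden_size, num_attention_heads = 7168, 128
head_dim = 64                      # the new DeepseekV3Config sets self.head_dim = qk_rope_head_dim
partial_rotary_factor = 1.0

old_dim = int((hidden_size // num_attention_heads) * partial_rotary_factor)   # 56 -- too small for MLA RoPE
new_dim = int(head_dim * partial_rotary_factor)                               # 64 -- matches qk_rope_head_dim
print(old_dim, new_dim)
```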
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]

     attention_factor = 1.0  # Unused in this type of RoPE

     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = config.qk_rope_head_dim
-    max_position_embeddings = config.max_position_embeddings
+    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0

     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)

     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)

-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity
@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func

     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
     # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)

-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )

     return inv_freq, attention_factor
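A worked illustration of the new YaRN attention-factor path (the numbers are hypothetical, chosen in the spirit of DeepSeek-V3 checkpoints, not read from this commit): `factor` is re-derived from the ratio of the extended and original context lengths, and when `mscale` equals `mscale_all_dim` the two `get_mscale` terms cancel to 1.0.

```python
import math

def get_mscale(scale, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

max_position_embeddings = 163840                  # hypothetical extended context
original_max_position_embeddings = 4096           # hypothetical pretrained context
factor = max_position_embeddings / original_max_position_embeddings   # 40.0, derived as in the diff

mscale, mscale_all_dim = 1.0, 1.0
print(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))  # 1.0 -- equal mscales cancel
print(get_mscale(factor))                                               # ~1.369 -- fallback when mscale_all_dim is unset
```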
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
         device (`torch.device`):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
         rope_kwargs (`Dict`, *optional*):
             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
     Returns:
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
@@ -271,22 +296,20 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings

     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))

     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > original_max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor

-    new_freqs = []
-    for freq in inv_freq:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
-    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
-    return inv_freq, attention_factor
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor


 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
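A standalone sanity check (with typical Llama-3-style numbers, assumed here rather than taken from the commit) that the vectorized rewrite reproduces the removed per-frequency loop:

```python
import math
import torch

factor, old_context_len = 8.0, 8192
low_freq_factor, high_freq_factor = 1.0, 4.0
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))

# Removed loop version
new_freqs = []
for freq in inv_freq:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
        new_freqs.append(freq)
    elif wavelen > low_freq_wavelen:
        new_freqs.append(freq / factor)
    else:
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
loop_result = torch.tensor(new_freqs)

# Added vectorized version
wavelen = 2 * math.pi / inv_freq
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

print(torch.allclose(loop_result, inv_freq_llama))  # expect True
```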
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }


-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
         received_keys.add("rope_type")
         required_keys.add("rope_type")

+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
+
     missing_keys = required_keys - received_keys
     if missing_keys:
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         ...
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )


-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
         ...
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

     attention_factor = rope_scaling.get("attention_factor")
-    if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-        logger.warning(
-            f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-        )
+    if attention_factor is not None:
+        if not isinstance(attention_factor, float) or attention_factor < 0.0:
+            logger.warning(
+                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+            )


-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         ...
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
             ...
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }


-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
\ No newline at end of file
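Taken together, the validation changes let a config carry model-specific keys in `rope_scaling` without triggering warnings. A hedged usage sketch (the extra key and all values here are made up for illustration):

```python
from types import SimpleNamespace
from ktransformers.util.modeling_rope_utils import rope_config_validation

# Minimal stand-in for a PretrainedConfig; only rope_scaling should be needed by the yarn validator.
cfg = SimpleNamespace(
    rope_scaling={
        "rope_type": "yarn",
        "factor": 40.0,
        "original_max_position_embeddings": 4096,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
        "my_custom_key": 123,          # hypothetical model-specific key that would normally warn
    }
)
rope_config_validation(cfg, ignore_keys={"my_custom_key"})   # suppresses the "unrecognized keys" warning
```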