Commit f873558a authored by Azure

update rope calculation; update modeling.py; update gate for moe

parent 5a50b346
@@ -54,4 +54,4 @@ long_context:
token_step: token_step:
local_chat: local_chat:
prompt_file: "./ktransformers/p.txt" prompt_file: ""
\ No newline at end of file
@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM
@@ -78,7 +78,7 @@ def local_chat():
else: else:
content += line + "\n" content += line + "\n"
if content == "": if content == "":
if config.prompt_file == None or config.prompt_file == "": if not config.prompt_file:
content = "hi" content = "hi"
else: else:
content = open(config.prompt_file, "r").read() content = open(config.prompt_file, "r").read()
......
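The `if not config.prompt_file:` simplification relies on Python truthiness: `None` and the empty string are both falsy, so the two explicit comparisons in the old code collapse into one check. A minimal illustration (the file path below is only illustrative):

```python
# Both None and "" are falsy, so `not prompt_file` covers the two cases
# the old code (`== None` / `== ""`) tested separately.
for prompt_file in (None, "", "./some/prompt.txt"):
    use_default = not prompt_file
    print(repr(prompt_file), "-> default prompt 'hi'" if use_default else "-> read prompt from file")
```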
@@ -14,19 +14,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" DeepSeekV3 model configuration """ """DeepSeekV3 model configuration"""
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV3Config(PretrainedConfig): class DeepseekV3Config(PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3. defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
vocab_size (`int`, *optional*, defaults to 129280): vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
Dimension of the MoE representations. Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 61): num_hidden_layers (`int`, *optional*, defaults to 61):
Number of hidden layers in the Transformer decoder. Number of hidden layers in the Transformer decoder.
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
Number of nextn predict layers in the DeepSeekV3 Model.
num_attention_heads (`int`, *optional*, defaults to 128): num_attention_heads (`int`, *optional*, defaults to 128):
Number of attention heads for each attention layer in the Transformer decoder. Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 128): num_key_value_heads (`int`, *optional*, defaults to 128):
@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`. `num_attention_heads`.
n_shared_experts (`int`, *optional*, defaults to 1): n_shared_experts (`int`, *optional*, defaults to 1):
Number of shared experts, None means dense model. Number of shared experts.
n_routed_experts (`int`, *optional*, defaults to 256): n_routed_experts (`int`, *optional*, defaults to 256):
Number of routed experts, None means dense model. Number of routed experts.
ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
routed_scaling_factor (`float`, *optional*, defaults to 2.5): routed_scaling_factor (`float`, *optional*, defaults to 2.5):
Scaling factor or routed experts. Scaling factor or routed experts.
kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring> kv_lora_rank (`int`, *optional*, defaults to 512):
q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring> Rank of the LoRA matrices for key and value projections.
qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring> q_lora_rank (`int`, *optional*, defaults to 1536):
v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring> Rank of the LoRA matrices for query projections.
qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring> qk_rope_head_dim (`int`, *optional*, defaults to 64):
topk_method (`str`, *optional*, defaults to `"noaux_tc"`): Dimension of the query/key heads that use rotary position embeddings.
Topk method used in routed gate. v_head_dim (`int`, *optional*, defaults to 128):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 128):
Dimension of the query/key heads that don't use rotary position embeddings.
n_group (`int`, *optional*, defaults to 8): n_group (`int`, *optional*, defaults to 8):
Number of groups for routed experts. Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to 4): topk_group (`int`, *optional*, defaults to 4):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to 8): num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of selected experts, None means dense model. Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 3): first_k_dense_replace (`int`, *optional*, defaults to 3):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/ \--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to `True`): norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the weights of the routed experts. Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001): aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient. Auxiliary loss weight coefficient.
Whether to compute the auxiliary loss for each individual sample. Whether to compute the auxiliary loss for each individual sample.
seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder. The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096): max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention. Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0): attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities. The dropout ratio for the attention probabilities.
```python ```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config >>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration >>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config() >>> configuration = DeepseekV3Config()
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
```""" ```"""
model_type = "deepseek_v3" model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `DeepseekV3Model`
base_model_tp_plan = {
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
}
def __init__( def __init__(
self, self,
vocab_size=129280, vocab_size=129280,
hidden_size=7168, hidden_size=7168,
intermediate_size=18432, intermediate_size=18432,
moe_intermediate_size = 2048, moe_intermediate_size=2048,
num_hidden_layers=61, num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128, num_attention_heads=128,
num_key_value_heads=128, num_key_value_heads=128,
n_shared_experts = 1, n_shared_experts=1,
n_routed_experts = 256, n_routed_experts=256,
ep_size = 1, routed_scaling_factor=2.5,
routed_scaling_factor = 2.5, kv_lora_rank=512,
kv_lora_rank = 512, q_lora_rank=1536,
q_lora_rank = 1536, qk_rope_head_dim=64,
qk_rope_head_dim = 64, v_head_dim=128,
v_head_dim = 128, qk_nope_head_dim=128,
qk_nope_head_dim = 128, n_group=8,
topk_method = 'noaux_tc', topk_group=4,
n_group = 8, num_experts_per_tok=8,
topk_group = 4, first_k_dense_replace=3,
num_experts_per_tok = 8, norm_topk_prob=True,
moe_layer_freq = 1, aux_loss_alpha=0.001,
first_k_dense_replace = 3,
norm_topk_prob = True,
scoring_func = 'sigmoid',
aux_loss_alpha = 0.001,
seq_aux = True,
hidden_act="silu", hidden_act="silu",
max_position_embeddings=4096, max_position_embeddings=4096,
initializer_range=0.02, initializer_range=0.02,
@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
rope_scaling=None, rope_scaling=None,
attention_bias=False, attention_bias=False,
attention_dropout=0.0, attention_dropout=0.0,
mlp_bias=False,
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.head_dim = qk_rope_head_dim
self.n_group = n_group self.n_group = n_group
self.topk_group = topk_group self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility # for backward compatibility
if num_key_value_heads is None: if num_key_value_heads is None:
num_key_value_heads = num_attention_heads num_key_value_heads = num_attention_heads
@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.attention_bias = attention_bias self.attention_bias = attention_bias
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias # Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
......
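Two details of the new config are worth pulling out of the diff: `__init__` now derives `q_head_dim` and `head_dim` from the nope/rope head dimensions, and it normalizes a legacy `"type"` key to `"rope_type"` before calling `rope_config_validation`. A condensed, hedged sketch of just that logic (a stand-in class, not the full `DeepseekV3Config`):

```python
class ConfigSketch:
    """Stand-in showing the derived fields the new DeepseekV3Config.__init__ sets."""

    def __init__(self, qk_nope_head_dim=128, qk_rope_head_dim=64, rope_scaling=None):
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Full query/key head size = non-rotary part + rotary part.
        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192 with the defaults
        # Only the rotary slice takes part in RoPE, so head_dim is the rope dimension.
        self.head_dim = qk_rope_head_dim                        # 64 with the defaults
        self.rope_scaling = rope_scaling
        # BC: older configs store the scaling variant under "type"; copy it to
        # "rope_type" so rope_config_validation and the RoPE init helpers can find it.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]


cfg = ConfigSketch(rope_scaling={"type": "yarn", "factor": 40.0})
print(cfg.q_head_dim, cfg.head_dim, cfg.rope_scaling["rope_type"])  # 192 64 yarn
```

Setting `head_dim = qk_rope_head_dim` is what lets the shared RoPE helpers further down pick up the 64-dim rotary slice instead of `hidden_size // num_attention_heads`.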
@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
# In-place ops prevent breaking the static address # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""
return self.max_cache_len
\ No newline at end of file
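The added `get_max_cache_shape` mirrors the accessor recent `transformers` cache utilities query on cache objects; note that, despite the `Tuple[...]` annotation, it returns the scalar `max_cache_len`. A hedged sketch of how such an accessor is typically consumed (the class and caller below are hypothetical stand-ins, not the real StaticCache):

```python
class StaticCacheLike:
    """Hypothetical stand-in for the patched StaticCache above."""

    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self) -> int:
        # As in the diff: returns the maximum number of cached positions.
        return self.max_cache_len


cache = StaticCacheLike(max_cache_len=8192)
prompt_len, max_new_tokens = 1000, 2000
# A generation loop would bound new tokens by the remaining cache capacity.
allowed_new_tokens = min(max_new_tokens, cache.get_max_cache_shape() - prompt_len)
print(allowed_new_tokens)  # 2000
```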
@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
from typing import Optional, Tuple from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, position_ids) cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
......
@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
identity = hidden_states identity = hidden_states
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
sequence_length = orig_shape[1] sequence_length = orig_shape[1]
topk_idx, topk_weight= self.gate(hidden_states) topk_idx, topk_weight, router_logits= self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:
@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y += y_ y += y_
y.resize_(*orig_shape) y.resize_(*orig_shape)
return y return y, router_logits
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0) y_ = self.shared_experts(identity).squeeze(0)
@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
) )
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:
y += y_ y += y_
return y return y, router_logits
@torch.no_grad() @torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
......
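With the new router, `self.gate(hidden_states)` yields router logits in addition to the top-k indices and weights, and both return paths of `forward` now propagate them. A hedged sketch of the new calling convention (the gate below is a stub with the same signature, not the real module):

```python
import torch


def fake_gate(hidden_states: torch.Tensor):
    """Stub with the new gate signature: (topk_idx, topk_weight, router_logits)."""
    n_tokens, n_experts, k = hidden_states.shape[0] * hidden_states.shape[1], 256, 8
    router_logits = torch.randn(n_tokens, n_experts)
    topk_weight, topk_idx = router_logits.softmax(dim=-1).topk(k, dim=-1)
    return topk_idx, topk_weight, router_logits


hidden_states = torch.randn(1, 4, 16)
topk_idx, topk_weight, router_logits = fake_gate(hidden_states)  # three values now
# ... expert dispatch would run here ...
y = torch.zeros_like(hidden_states)
output = (y, router_logits)  # forward() now returns the logits as well
```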
@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.models.modeling_deepseekv3 import MoEGate from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
from ktransformers.util.utils import InferenceState from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
else: else:
raise ValueError("Invalid weight type") raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device) self.orig_module.weight = self.orig_module.weight.to(device)
if self.topk_method == "noaux_tc": self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
def unload(self): def unload(self):
if self.weight is not None: if self.weight is not None:
self.weight = None self.weight = None
if self.topk_method == "noaux_tc": if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None self.e_score_correction_bias = None
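Because `DeepseekV3TopkRouter` always carries an `e_score_correction_bias`, the old `topk_method == "noaux_tc"` guard disappears: `load` moves the bias to the target device unconditionally and `unload` clears it whenever it is set. A condensed stand-in for that load/unload pairing (not the real operator class):

```python
import torch


class GateSketch:
    """Stand-in mirroring the simplified KMoEGate load/unload logic."""

    def __init__(self):
        self.weight = torch.zeros(256, 16)
        self.e_score_correction_bias = torch.zeros(256)

    def load(self, device: str = "cpu"):
        self.weight = self.weight.to(device)
        # No topk_method check anymore: the V3 router always has the bias.
        self.e_score_correction_bias = self.e_score_correction_bias.to(device)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None


g = GateSketch()
g.load("cpu")
g.unload()
```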
@@ -47,7 +47,7 @@
- match: - match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace: replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs: kwargs:
@@ -55,7 +55,7 @@
prefill_device: "cuda:0" prefill_device: "cuda:0"
- match: - match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$" name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace: replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs: kwargs:
@@ -64,7 +64,7 @@
- match: - match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseekv3.MoEGate class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGate
kwargs: kwargs:
@@ -72,7 +72,7 @@
prefill_device: "cuda:0" prefill_device: "cuda:0"
- match: - match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseekv3.MoEGate class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
replace: replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs: kwargs:
......
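The two match rules split DeepSeek-V3's 61 decoder layers by regex on the layer index: `(0|[1-9]|[12][0-9])` covers layers 0-29 and `([3456][0-9])` covers layers 30-60 (indices above 60 never occur). A quick check of that split with the same patterns (module names are illustrative):

```python
import re

first = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.mlp\.gate$")
second = re.compile(r"^model\.layers\.([3456][0-9])\.mlp\.gate$")

first_hits = [i for i in range(61) if first.match(f"model.layers.{i}.mlp.gate")]
second_hits = [i for i in range(61) if second.match(f"model.layers.{i}.mlp.gate")]
print(first_hits[0], first_hits[-1], len(first_hits))    # 0 29 30
print(second_hits[0], second_hits[-1], len(second_hits)) # 30 60 31
```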
@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
self.total_context = self.model.get("total_context", 2**18) self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048) self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.max_new_tokens = self.model.get("max_new_tokens", 500) self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False) self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False) self.healing = self.model.get("healing", False)
self.ban_strings: Optional[list] = self.model.get("ban_strings", None) self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
......
@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
elif config is not None: elif config is not None:
base = config.rope_theta base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE attention_factor = 1.0 # Unused in this type of RoPE
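The change above lets a config advertise an explicit `head_dim` (as the updated `DeepseekV3Config` now does, with `head_dim = qk_rope_head_dim = 64`) instead of always assuming `hidden_size // num_attention_heads`. A hedged sketch of the inverse-frequency computation with that fallback, using a plain namespace in place of a real config:

```python
from types import SimpleNamespace

import torch


def default_inv_freq(config, device="cpu"):
    base = config.rope_theta
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    # New behaviour: prefer an explicit head_dim, fall back to hidden_size // heads.
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))


cfg = SimpleNamespace(rope_theta=10000.0, hidden_size=7168, num_attention_heads=128, head_dim=64)
print(default_inv_freq(cfg).shape)  # torch.Size([32]) -- 64-dim rotary slice, not 7168 // 128 = 56
```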
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
elif config is not None: elif config is not None:
base = config.rope_theta base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"] factor = config.rope_scaling["factor"]
attention_factor = 1.0 # Unused in this type of RoPE attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time # seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None else max_position_embeddings seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
# Compute the inverse frequencies # Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
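Two things change for dynamic NTK: `seq_len` is clamped from below by `max_position_embeddings`, so sequences shorter than the pretrained context keep the vanilla base, and only longer sequences rescale it via the formula above. A worked example with assumed values:

```python
# Assumed values, for illustration only.
base, dim = 10000.0, 64
factor, max_position_embeddings = 4.0, 4096

for seq_len in (1024, 4096, 16384):
    # New clamp: never scale below the pretrained context length.
    effective_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
    scaled_base = base * ((factor * effective_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    print(seq_len, round(scaled_base, 1))
# 1024 and 4096 keep base = 10000.0; 16384 grows it to roughly 10000 * 13**(64/62) ≈ 1.4e5.
```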
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
base = config.rope_theta base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = config.qk_rope_head_dim head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"] factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in config.rope_scaling:
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings
def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Sets the attention factor as suggested in the paper # Sets the attention factor as suggested in the paper
attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None: if attention_factor is None:
attention_factor = 0.1 * math.log(factor) + 1.0 if mscale and mscale_all_dim:
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)
# Optional config options # Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(min, max, dim): def linear_ramp_factor(min, max, dim):
if min == max: if min == max:
max += 0.001 # Prevent singularity max += 0.001 # Prevent singularity
@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs) inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation # Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device) inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, attention_factor return inv_freq, attention_factor
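The YaRN path now reproduces DeepSeek's scaling conventions: when `rope_scaling` carries `original_max_position_embeddings`, the effective factor is the context-extension ratio, and when `mscale`/`mscale_all_dim` are present the attention factor becomes a ratio of two `get_mscale` terms rather than the plain `0.1·log(factor) + 1` rule. A small sketch of just the attention-factor logic (the numeric values are illustrative, not copied from any released config):

```python
import math


def get_mscale(scale, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_attention_factor(rope_scaling, max_position_embeddings):
    factor = rope_scaling["factor"]
    if "original_max_position_embeddings" in rope_scaling:
        # DeepSeek-style configs: derive the effective factor from the context-extension ratio.
        factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]
    attention_factor = rope_scaling.get("attention_factor")
    if attention_factor is None:
        mscale, mscale_all_dim = rope_scaling.get("mscale"), rope_scaling.get("mscale_all_dim")
        if mscale and mscale_all_dim:
            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)
    return attention_factor


# Illustrative values in the DeepSeek-V3 style (not copied from any config file).
rope_scaling = {"rope_type": "yarn", "factor": 40.0, "original_max_position_embeddings": 4096,
                "mscale": 1.0, "mscale_all_dim": 1.0}
print(yarn_attention_factor(rope_scaling, max_position_embeddings=163840))  # 1.0 when mscale == mscale_all_dim
```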
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
device (`torch.device`): device (`torch.device`):
The device to use for initialization of the inverse frequencies. The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*): seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE. The current sequence length.
rope_kwargs (`Dict`, *optional*): rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns: Returns:
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
base = config.rope_theta base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"] long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"] short_factor = config.rope_scaling["short_factor"]
factor = config.rope_scaling.get("factor") factor = config.rope_scaling.get("factor")
@@ -271,22 +296,20 @@
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`. # values to compute the default attention scaling factor, instead of using `factor`.
if hasattr(config, "original_max_position_embeddings"): if hasattr(config, "original_max_position_embeddings"):
max_position_embeddings = config.original_max_position_embeddings original_max_position_embeddings = config.original_max_position_embeddings
expanded_max_position_embeddings = config.max_position_embeddings factor = config.max_position_embeddings / config.original_max_position_embeddings
factor = expanded_max_position_embeddings / max_position_embeddings
else: else:
max_position_embeddings = config.max_position_embeddings original_max_position_embeddings = config.max_position_embeddings
expanded_max_position_embeddings = max_position_embeddings * factor
# Sets the attention factor as suggested in the paper # Sets the attention factor as suggested in the paper
if attention_factor is None: if attention_factor is None:
if factor <= 1.0: if factor <= 1.0:
attention_factor = 1.0 attention_factor = 1.0
else: else:
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
# Compute the inverse frequencies -- scaled based on the target sequence length # Compute the inverse frequencies -- scaled based on the target sequence length
if expanded_max_position_embeddings > max_position_embeddings: if seq_len and seq_len > original_max_position_embeddings:
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
else: else:
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
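For LongRoPE, the choice between `long_factor` and `short_factor` is now keyed off the runtime `seq_len` compared against `original_max_position_embeddings`, instead of comparing two config-derived lengths, and the attention factor is computed from the original context length. A minimal sketch of that selection rule with stand-in values:

```python
import math


def longrope_selection(seq_len, original_max_position_embeddings, factor, attention_factor=None):
    # New rule: the "long" rescale factors only kick in once the actual sequence
    # outgrows the pretrained context window.
    use_long_factors = bool(seq_len) and seq_len > original_max_position_embeddings
    if attention_factor is None:
        attention_factor = 1.0 if factor <= 1.0 else math.sqrt(
            1 + math.log(factor) / math.log(original_max_position_embeddings)
        )
    return use_long_factors, attention_factor


print(longrope_selection(2048, 4096, factor=32.0))   # (False, ~1.19)
print(longrope_selection(65536, 4096, factor=32.0))  # (True, ~1.19)
```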
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
low_freq_wavelen = old_context_len / low_freq_factor low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq: wavelen = 2 * math.pi / inv_freq
wavelen = 2 * math.pi / freq # wavelen < high_freq_wavelen: do nothing
if wavelen < high_freq_wavelen: # wavelen > low_freq_wavelen: divide by factor
new_freqs.append(freq) inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
elif wavelen > low_freq_wavelen: # otherwise: interpolate between the two, using a smooth factor
new_freqs.append(freq / factor) smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
else: smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
assert low_freq_wavelen != high_freq_wavelen is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) return inv_freq_llama, attention_factor
return inv_freq, attention_factor
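The Llama-3 scaling is now fully vectorised: wavelengths are computed for the whole `inv_freq` tensor at once, and `torch.where` picks between the untouched high-frequency band, the `1/factor`-scaled low-frequency band, and a smoothly interpolated middle band. A self-contained sketch of the same computation (the parameters are illustrative, not taken from a specific checkpoint):

```python
import math

import torch

# Illustrative Llama-3-style parameters.
factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2, dtype=torch.float) / 128))

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

wavelen = 2 * math.pi / inv_freq
# Low-frequency band (long wavelengths): divide by factor; high-frequency band: keep as-is.
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# Medium band: interpolate between the two regimes with a smooth factor.
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
print(inv_freq_llama.shape)  # torch.Size([64])
```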
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
} }
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): def _check_received_keys(
rope_type: str,
received_keys: set,
required_keys: set,
optional_keys: Optional[set] = None,
ignore_keys: Optional[set] = None,
):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys""" """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's gracefully handle it # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "rope_type" not in received_keys and "type" in received_keys: if "type" in received_keys:
received_keys -= {"type"} received_keys -= {"type"}
received_keys.add("rope_type") required_keys.add("rope_type")
# Some models need to store model-specific keys, and we don't want to throw warning at them
if ignore_keys is not None:
received_keys -= ignore_keys
missing_keys = required_keys - received_keys missing_keys = required_keys - received_keys
if missing_keys: if missing_keys:
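`_check_received_keys` now accepts an `ignore_keys` set so models can keep private fields inside `rope_scaling` without triggering the unrecognized-key warning, and a legacy `"type"` entry now makes `"rope_type"` a required key rather than being silently rewritten. A sketch of the key bookkeeping in isolation:

```python
def check_received_keys(received_keys, required_keys, optional_keys=None, ignore_keys=None):
    received_keys, required_keys = set(received_keys), set(required_keys)
    # BC: "rope_type" was originally "type" -- require "rope_type" when "type" is present.
    if "type" in received_keys:
        received_keys -= {"type"}
        required_keys.add("rope_type")
    # Model-specific private keys are dropped before the comparison.
    if ignore_keys is not None:
        received_keys -= ignore_keys
    missing = required_keys - received_keys
    unused = received_keys - required_keys - (optional_keys or set())
    return missing, unused


print(check_received_keys({"type", "rope_type", "factor", "mscale"}, {"rope_type", "factor"},
                          ignore_keys={"mscale"}))  # (set(), set())
```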
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig): def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"} required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"} required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"] factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0: if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"} required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings` # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"} optional_keys = {"original_max_position_embeddings"}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"] factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0: if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig): def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"} required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"} optional_keys = {
"attention_factor",
"beta_fast",
"beta_slow",
"original_max_position_embeddings",
"mscale",
"mscale_all_dim",
}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"] factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0: if factor is None or not isinstance(factor, float) or factor < 1.0:
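With `original_max_position_embeddings`, `mscale`, and `mscale_all_dim` added to the optional keys, a DeepSeek-style YaRN block validates without warnings. An illustrative example of such a `rope_scaling` dict (field names from the diff; the values are made up for the example):

```python
# Illustrative DeepSeek-style YaRN scaling block; every key here is either
# required ("rope_type", "factor") or in the new optional-key set above.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 40.0,
    "beta_fast": 32,
    "beta_slow": 1,
    "original_max_position_embeddings": 4096,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
}
```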
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
) )
def _validate_longrope_parameters(config: PretrainedConfig): def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"} required_keys = {"rope_type", "short_factor", "long_factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings` # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
short_factor = rope_scaling.get("short_factor") short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor") attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: if attention_factor is not None:
logger.warning( if not isinstance(attention_factor, float) or attention_factor < 0.0:
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" logger.warning(
) f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig): def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys()) received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"] factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0: if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float): if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor: if high_freq_factor <= low_freq_factor:
logger.warning( logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}" f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
} }
def rope_config_validation(config: PretrainedConfig): def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
""" """
Validate the RoPE config arguments, given a `PretrainedConfig` object Validate the RoPE config arguments, given a `PretrainedConfig` object
""" """
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None: if validation_fn is not None:
validation_fn(config) validation_fn(config, ignore_keys=ignore_keys)
else: else:
logger.warning( logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
) )
\ No newline at end of file