Unverified Commit 2390a2bc authored by Meng, Peng, committed by GitHub

Add Tencent HunYuanMoEV1 model support (#7549)

parent 16d76b9f
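
A quick reviewer-oriented sketch of what the new alpha-based dynamic NTK path computes. This is not part of the patch, and the values below are illustrative assumptions, not taken from a real HunYuan checkpoint:

# Illustrative sketch only; base/rotary_dim/alpha are assumed values.
base, rotary_dim, alpha = 10000, 128, 1000.0
scaled_base = base * alpha ** (rotary_dim / (rotary_dim - 2))
print(f"{scaled_base:.3e}")  # ~1.116e+07
# A larger effective base lowers every inverse frequency, so positions beyond the
# original training window still map to in-range rotation angles; get_rope() takes
# this path whenever a "dynamic" rope_scaling config also carries an "alpha" key.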
...@@ -890,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
return query_out.type_as(query), key_out.type_as(key)
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_cos_sin_cache(self) -> torch.Tensor:
max_len = self.max_position_embeddings
base = self.base * self.scaling_alpha ** (
self.rotary_dim / (self.rotary_dim - 2)
)
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
...@@ -1234,15 +1271,26 @@ def get_rope(
        )
    elif scaling_type == "dynamic":
        scaling_factor = rope_scaling["factor"]
        if "alpha" in rope_scaling:
            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                rope_scaling["alpha"],
                dtype,
            )
        else:
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
    elif scaling_type == "yarn":
        scaling_factor = rope_scaling["factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
...
# coding=utf-8
# Copyright 2024 The HunYuan team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
import logging
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, is_hip
expert_distribution_recorder = ExpertDistributionRecorder()
def _is_moe(config: PretrainedConfig) -> bool:
if getattr(config, "num_experts", None) and (
(isinstance(config.num_experts, int) and config.num_experts > 1)
or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
):
return True
else:
return False
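# CLA (cross-layer attention): when config.use_cla is set, only every
# cla_share_factor-th decoder layer computes fresh K/V ("self" attention);
# the layers in between project only Q and reuse an earlier layer's K/V
# ("cross" attention below).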
def _get_cla_factor(config: PretrainedConfig) -> int:
if not getattr(config, "use_cla", False):
return 1
return getattr(config, "cla_share_factor", 1)
class HunYuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
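# gate_up packs [gate, up] along the last dim; SiluAndMul computes
# silu(gate) * up, and down_proj maps the result back to hidden_size.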
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class HunYuanSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
layer_id: int = -1,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)
# Get layer_id topk if config.moe_topk is a list
if isinstance(config.moe_topk, list):
assert layer_id >= 0
assert len(config.moe_topk) > layer_id
top_k = config.moe_topk[layer_id]
else:
top_k = config.moe_topk
# For MoE layers, moe_intermediate_size takes precedence when provided.
intermediate_size = config.intermediate_size
if config.moe_intermediate_size is not None:
intermediate_size = (
config.moe_intermediate_size
if isinstance(config.moe_intermediate_size, int)
else config.moe_intermediate_size[layer_id]
)
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True if top_k > 1 else False,
quant_config=quant_config,
)
self.gate = ReplicatedLinear(
config.hidden_size, config.num_experts, bias=False, quant_config=None
)
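# The router gate is replicated on every TP rank and left unquantized; its
# logits are consumed by FusedMoE to pick top_k experts per token.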
if config.use_mixed_mlp_moe > 0:
# Get layer_id num_shared_expert if config.num_shared_expert is a list
if isinstance(config.num_shared_expert, list):
assert layer_id >= 0
assert len(config.num_shared_expert) > layer_id
num_shared_expert = config.num_shared_expert[layer_id]
else:
num_shared_expert = config.num_shared_expert
self.shared_mlp = HunYuanMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size * num_shared_expert,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
else:
self.shared_mlp = None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_mlp is not None:
shared_output = self.shared_mlp(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
class HunYuanAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
attention_type: str = "self",
layer_id: int = -1,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# Some configs carry an explicit, optional head_dim (introduced by Mistral-Nemo);
# otherwise fall back to hidden_size // num_attention_heads.
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.attention_type = attention_type
self.layer_id = layer_id
if attention_type == "self":
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
elif attention_type == "cross":
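# CLA cross-attention layers project only Q; K/V are supplied through
# kv_states in forward() by the most recent self-attention layer.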
self.q_proj = ColumnParallelLinear(
hidden_size,
hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
else:
raise RuntimeError("Not support attnention type")
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=f"{prefix}.attn",
)
if self.use_qk_norm:
self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.attention_type == "self":
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
ori_k = k
if self.use_qk_norm:
# q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
# k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous())
k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous())
elif self.attention_type == "cross":
assert kv_states is not None
ori_k, v = kv_states  # reuse the KV states cached by an earlier self-attention layer (CLA)
k = ori_k
q, _ = self.q_proj(hidden_states)
k_tmp = torch.empty_like(k)  # TODO: redundant rotary embedding applied to a throwaway key buffer
q, _ = self.rotary_emb(positions, q, k_tmp)
if self.use_qk_norm:
q = self.query_layernorm(
q.view(-1, self.num_heads, self.head_dim).contiguous()
)
k = self.key_layernorm(
k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
)
else:
raise RuntimeError(f"Unsupported attention type: {self.attention_type}")
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output, (ori_k, v)
class HunYuanDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = -1,
) -> None:
super().__init__()
assert layer_id >= 0
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.intermediate_size = (
config.intermediate_size
if isinstance(config.intermediate_size, int)
else config.intermediate_size[layer_id]
)
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
cla_factor = _get_cla_factor(config)
attention_type = (
"cross" if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
)
self.self_attn = HunYuanAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
prefix=f"{prefix}.self_attn",
attention_type=attention_type,
layer_id=layer_id,
)
if _is_moe(config):
self.mlp = HunYuanSparseMoeBlock(
config=config,
quant_config=quant_config,
layer_id=layer_id,
)
else:
self.mlp = HunYuanMLP(
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, ori_kv_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
kv_states=kv_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual, ori_kv_states
class HunYuanModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
HunYuanDecoderLayer(
config=config,
layer_id=layer_id,
quant_config=quant_config,
# prefix=prefix
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if input_embeds is not None:
hidden_states = input_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
forward_batch,
residual,
prev_kv_states,
)
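# CLA KV reuse is currently disabled (the `if False` below); when enabled,
# layers satisfying the commented condition would hand their KV states to the
# following cross-attention layers.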
if False: # (i - self.start_layer) % cla_factor == 0:
prev_kv_states = kv_states
else:
prev_kv_states = None
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class HunYuanMoEV1ForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = HunYuanModel(config, quant_config, prefix="model")
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def _split_qkv_weight(self, qkv: torch.Tensor):
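# The fused checkpoint QKV weight is laid out per KV head as
# [num_kv_heads, num_key_value_groups + 2, head_dim, hidden_size]
# (the group's query rows, then one K block, then one V block); split it and
# re-concatenate the rows as [q; k; v] so the qkv_proj weight loader can shard it.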
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(
self.config, "num_key_value_heads", self.config.num_attention_heads
)
num_key_value_groups = num_attention_heads // num_kv_heads
hidden_size = self.config.hidden_size
attention_head_dim = self.config.hidden_size // num_attention_heads
qkv = qkv.reshape(
num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
)
q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
q = q.reshape(-1, hidden_size)
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
return torch.concat((q, k, v))
# return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(
self.config, "num_key_value_heads", self.config.num_attention_heads
)
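# Each entry: (param_name, checkpoint_name, number of row units the fused
# checkpoint tensor divides into, [(shard_id, units per shard), ...],
# optional preprocessing fn applied before slicing).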
split_params_mapping = [
(".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
(
".qkv_proj",
".qkv_proj",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
self._split_qkv_weight,
),
]
if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
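# Loading order per checkpoint tensor: (1) stacked q/k/v and gate/up shards,
# (2) fused gate_and_up_proj / qkv_proj tensors that must be split first,
# (3) per-expert weights, (4) everything else via the default loader.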
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "gate_proj_bias" in name:
name = name.replace("gate_proj_bias", "gate_proj.bias")
if "up_proj_bias" in name:
name = name.replace("up_proj_bias", "up_proj.bias")
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
is_found = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
# Cross-attention (CLA) layers only have q_proj, so skip QKV packing.
if weight_name == ".q_proj":
match = re.search(r"layers\.\d+", name)
if match:
layer_id = int(match.group(0).split(".")[-1])
if cla_factor > 1 and layer_id % cla_factor != 0:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
is_found = True
break
if is_found:
continue
for param_name, weight_name, den, split_param, func in split_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
assert loaded_weight.shape[0] % den == 0
units = loaded_weight.shape[0] // den
param = params_dict[name]
weight_loader = param.weight_loader
offset = 0
for shard_id, num in split_param:
new_offset = offset + num * units
if func:
weight_loader(
param, func(loaded_weight)[offset:new_offset], shard_id
)
else:
weight_loader(param, loaded_weight[offset:new_offset], shard_id)
offset = new_offset
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if "mlp.gate.wg." in name:
name = name.replace("wg.", "")
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError(
"Self attention has no KV cache scaling " "factor attribute!"
)
EntryClass = HunYuanMoEV1ForCausalLM