Commit f386ba88 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Models] support HunYuanForCausalLM

parent a9c37628
...@@ -1869,6 +1869,11 @@ def _get_and_verify_max_len( ...@@ -1869,6 +1869,11 @@ def _get_and_verify_max_len(
if rope_type == "yarn": if rope_type == "yarn":
derived_max_model_len = rope_scaling[ derived_max_model_len = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
# see DynamicNTKAlphaRotaryEmbedding
if rope_scaling["type"] == "dynamic" and "alpha" in rope_scaling:
scaling_factor = 1
derived_max_model_len *= scaling_factor derived_max_model_len *= scaling_factor
if encoder_config and "max_seq_length" in encoder_config: if encoder_config and "max_seq_length" in encoder_config:
......
...@@ -137,9 +137,14 @@ def get_rope( ...@@ -137,9 +137,14 @@ def get_rope(
scaling_alpha, dtype) scaling_alpha, dtype)
elif "factor" in rope_scaling: elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding( if "alpha" in rope_scaling:
head_size, rotary_dim, max_position, base, is_neox_style, rotary_emb = DynamicNTKAlphaRotaryEmbedding(
scaling_factor, dtype) head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling["alpha"], dtype)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
else: else:
raise ValueError("Dynamic rope scaling must contain either " raise ValueError("Dynamic rope scaling must contain either "
"'alpha' or 'factor' field") "'alpha' or 'factor' field")
......
This diff is collapsed.
...@@ -889,7 +889,7 @@ class HunYuanModel(nn.Module): ...@@ -889,7 +889,7 @@ class HunYuanModel(nn.Module):
return loaded_params return loaded_params
class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -931,30 +931,6 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -931,30 +931,6 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
# Set MoE hyperparameters
self.expert_weights = []
self.num_expert_groups = 1
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, HunYuanDecoderLayer)
if isinstance(layer.mlp, HunYuanSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No HunYuanMoE layer found in model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: torch.Tensor, expert_load_view: torch.Tensor,
...@@ -1030,13 +1006,120 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -1030,13 +1006,120 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
) )
return loader.load_weights(weights) return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
class HunYuanMoEV1Base(HunYuanV1Base, MixtureOfExperts):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()
# Set MoE hyperparameters
self.expert_weights = []
self.num_expert_groups = 1
self.moe_layers: list[FusedMoE] = [] # list[SharedFusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, HunYuanDecoderLayer)
if isinstance(layer.mlp, HunYuanSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No HunYuanMoE layer found in model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
# Register the expert weights.
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, HunYuanSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
class HunYuanDenseV1ForCausalLM(HunYuanV1Base): class HunYuanDenseV1Base(HunYuanV1Base):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
pass pass
class HunYuanMoEV1ForCausalLM(HunYuanV1Base): class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
pass pass
\ No newline at end of file
...@@ -100,6 +100,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -100,6 +100,7 @@ _TEXT_GENERATION_MODELS = {
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
"HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
"HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment