Commit 70fdd8a2 authored by guanyu1's avatar guanyu1
Browse files

hunyuan分类模型适配

parent e0ba5f60
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Literal, Optional, Union from typing import Any, Iterable,Literal, Optional, Union
import vllm.envs as envs import vllm.envs as envs
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -414,6 +414,53 @@ class ReplicatedLinear(LinearBase): ...@@ -414,6 +414,53 @@ class ReplicatedLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
"""Load parameters from (name, tensor) pairs into this layer."""
params = dict(self.named_parameters(recurse=False))
buffers = dict(self.named_buffers(recurse=False))
loaded: set[str] = set()
for weight_name, loaded_weight in weights:
# Default to the primary weight parameter if no suffix is given.
target_name = weight_name or "weight"
if target_name in params:
param = params[target_name]
weight_loader = getattr(param, "weight_loader",
self.weight_loader)
weight_loader(param, loaded_weight)
loaded.add(target_name)
continue
if target_name in buffers:
buffer = buffers[target_name]
if buffer.shape != loaded_weight.shape:
raise ValueError(
f"Shape mismatch when loading buffer '{target_name}': "
f"expected {buffer.shape}, got {loaded_weight.shape}")
buffer.copy_(loaded_weight)
loaded.add(target_name)
continue
attr = getattr(self, target_name, None)
if isinstance(attr, torch.Tensor):
if attr.shape != loaded_weight.shape:
raise ValueError(
f"Shape mismatch when loading tensor '{target_name}': "
f"expected {attr.shape}, got {loaded_weight.shape}")
attr.copy_(loaded_weight)
loaded.add(target_name)
continue
raise ValueError(
f"Unexpected weight '{target_name}' for "
f"{type(self).__name__}")
return loaded
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one # If the weight on disk does not have a shape, give it one
......
...@@ -30,7 +30,12 @@ from vllm.utils import is_pin_memory_available ...@@ -30,7 +30,12 @@ from vllm.utils import is_pin_memory_available
import vllm.envs as envs import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
from ..models.adapters_custom.adapters_classify import (
as_hunyuan_seq_cls_model,
)
CLASSIFY_CLASSIFY_REGISTRY = {
"HunYuanForCausalLM": as_hunyuan_seq_cls_model,
}
@contextlib.contextmanager @contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype): def set_default_torch_dtype(dtype: torch.dtype):
...@@ -257,6 +262,9 @@ def _get_model_architecture( ...@@ -257,6 +262,9 @@ def _get_model_architecture(
logger.debug_once("Converting to embedding model.") logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls) model_cls = as_embedding_model(model_cls)
elif convert_type == "classify": elif convert_type == "classify":
if arch in CLASSIFY_CLASSIFY_REGISTRY.keys():
model_cls = CLASSIFY_CLASSIFY_REGISTRY[arch](model_cls)
else:
logger.debug_once("Converting to sequence classification model.") logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls) model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward": elif convert_type == "reward":
......
...@@ -36,6 +36,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,6 +36,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -553,23 +556,23 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -553,23 +556,23 @@ class HunYuanDecoderLayer(nn.Module):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None, kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states residual=hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, ori_kv_states = self.self_attn( hidden_states, ori_kv_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_states=kv_states, kv_states=kv_states,
) )
hidden_states =residual+hidden_states
residual=hidden_states
hidden_states= self.post_attention_layernorm(hidden_states)
hidden_states=self.mlp(hidden_states)
hidden_states=hidden_states+residual
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual, ori_kv_states return hidden_states, residual, ori_kv_states
...@@ -614,11 +617,13 @@ class HunYuanModel(nn.Module): ...@@ -614,11 +617,13 @@ class HunYuanModel(nn.Module):
prefix=prefix, prefix=prefix,
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -662,7 +667,7 @@ class HunYuanModel(nn.Module): ...@@ -662,7 +667,7 @@ class HunYuanModel(nn.Module):
"hidden_states": hidden_states, "hidden_states": hidden_states,
"residual": residual "residual": residual
}) })
if not self.config.add_classification_head:
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -732,6 +737,10 @@ class HunYuanModel(nn.Module): ...@@ -732,6 +737,10 @@ class HunYuanModel(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping() expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.startswith("norm."):
# Some checkpoints omit the final norm; treat as handled.
loaded_params.add(name)
continue
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if "gate_proj_bias" in name: if "gate_proj_bias" in name:
...@@ -880,8 +889,16 @@ class HunYuanModel(nn.Module): ...@@ -880,8 +889,16 @@ class HunYuanModel(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if "norm.weight" in params_dict:
loaded_params.add("norm.weight")
return loaded_params return loaded_params
class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
...@@ -902,8 +919,11 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -902,8 +919,11 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.pad_id = self.config.pad_id
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
...@@ -924,6 +944,7 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -924,6 +944,7 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: torch.Tensor, expert_load_view: torch.Tensor,
...@@ -958,17 +979,25 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -958,17 +979,25 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors, model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -992,11 +1021,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -992,11 +1021,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( skip_prefixes = []
self, if self.config.tie_word_embeddings:
skip_prefixes=(["lm_head."] skip_prefixes.append("lm_head.")
if self.config.tie_word_embeddings else None), loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
)
return loader.load_weights(weights) return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment