"examples/hello_world/multinode_example/components/graph.py" did not exist on "df51a622dd6f9bf1e44e446535999210bad0797a"
Commit 70fdd8a2 authored by guanyu1
Browse files

hunyuan分类模型适配

parent e0ba5f60
......@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Iterable,Literal, Optional, Union
import vllm.envs as envs
import torch
import torch.nn as nn
......@@ -414,6 +414,53 @@ class ReplicatedLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def load_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
    """Load parameters from (name, tensor) pairs into this layer.

    Each name is resolved, in order, against this module's direct
    parameters, its direct buffers, and finally any tensor-valued
    attribute of the same name. An empty name is treated as ``"weight"``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs to load.

    Returns:
        The set of target names that were loaded.

    Raises:
        ValueError: If a buffer or tensor attribute has a different shape
            from the incoming tensor, or if a name matches nothing on
            this layer.
    """

    def copy_checked(kind: str, name: str, dest: torch.Tensor,
                     src: torch.Tensor) -> None:
        # Shared shape validation for the buffer and tensor-attribute
        # paths; parameters go through their weight_loader instead.
        if dest.shape != src.shape:
            raise ValueError(
                f"Shape mismatch when loading {kind} '{name}': "
                f"expected {dest.shape}, got {src.shape}")
        dest.copy_(src)

    params = dict(self.named_parameters(recurse=False))
    buffers = dict(self.named_buffers(recurse=False))
    loaded: set[str] = set()
    for weight_name, loaded_weight in weights:
        # Default to the primary weight parameter if no suffix is given.
        target_name = weight_name or "weight"
        if target_name in params:
            param = params[target_name]
            # Prefer a loader attached to the parameter itself; fall back
            # to this layer's own weight_loader.
            weight_loader = getattr(param, "weight_loader",
                                    self.weight_loader)
            weight_loader(param, loaded_weight)
        elif target_name in buffers:
            copy_checked("buffer", target_name, buffers[target_name],
                         loaded_weight)
        else:
            attr = getattr(self, target_name, None)
            if not isinstance(attr, torch.Tensor):
                raise ValueError(
                    f"Unexpected weight '{target_name}' for "
                    f"{type(self).__name__}")
            copy_checked("tensor", target_name, attr, loaded_weight)
        loaded.add(target_name)
    return loaded
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
......
......@@ -30,7 +30,12 @@ from vllm.utils import is_pin_memory_available
import vllm.envs as envs
logger = init_logger(__name__)
from ..models.adapters_custom.adapters_classify import (
as_hunyuan_seq_cls_model,
)
# Per-architecture overrides for the "classify" conversion: the registry is
# looked up by architecture name, and a matching entry is called with the
# model class; its return value replaces the model class. Architectures
# without an entry are not handled through this table.
CLASSIFY_CLASSIFY_REGISTRY = {
"HunYuanForCausalLM": as_hunyuan_seq_cls_model,
}
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
......@@ -257,6 +262,9 @@ def _get_model_architecture(
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
if arch in CLASSIFY_CLASSIFY_REGISTRY.keys():
model_cls = CLASSIFY_CLASSIFY_REGISTRY[arch](model_cls)
else:
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward":
......
......@@ -36,6 +36,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -553,23 +556,23 @@ class HunYuanDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
residual=hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, ori_kv_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_states=kv_states,
)
hidden_states =residual+hidden_states
residual=hidden_states
hidden_states= self.post_attention_layernorm(hidden_states)
hidden_states=self.mlp(hidden_states)
hidden_states=hidden_states+residual
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual, ori_kv_states
......@@ -614,11 +617,13 @@ class HunYuanModel(nn.Module):
prefix=prefix,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
    embedded = self.embed_tokens(input_ids)
    return embedded
......@@ -662,7 +667,7 @@ class HunYuanModel(nn.Module):
"hidden_states": hidden_states,
"residual": residual
})
if not self.config.add_classification_head:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -732,6 +737,10 @@ class HunYuanModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if name.startswith("norm."):
# Some checkpoints omit the final norm; treat as handled.
loaded_params.add(name)
continue
if "rotary_emb.inv_freq" in name:
continue
if "gate_proj_bias" in name:
......@@ -880,8 +889,16 @@ class HunYuanModel(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if "norm.weight" in params_dict:
loaded_params.add("norm.weight")
return loaded_params
class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
......@@ -902,8 +919,11 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.pad_id = self.config.pad_id
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
......@@ -924,6 +944,7 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
......@@ -958,17 +979,25 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run ``self.model`` on the inputs and return its output unchanged.

    ``token_type_ids`` is accepted in the signature but is not forwarded
    to ``self.model``.
    """
    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds,
    )
def compute_logits(
self,
hidden_states: torch.Tensor,
......@@ -992,11 +1021,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model via ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
        the set of loaded parameter names).
    """
    # When embeddings are tied, "lm_head." entries are skipped;
    # presumably the head reuses the embedding weights -- TODO confirm.
    skip_prefixes = (["lm_head."]
                     if self.config.tie_word_embeddings else [])
    # Build the loader once; always pass a list, never None.
    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment