Commit 70fdd8a2 authored by guanyu1
Browse files

hunyuan分类模型适配

parent e0ba5f60
......@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Iterable,Literal, Optional, Union
import vllm.envs as envs
import torch
import torch.nn as nn
......@@ -414,6 +414,53 @@ class ReplicatedLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
def load_weights(
    self,
    weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
    """Copy ``(name, tensor)`` pairs into this layer's own tensors.

    Each name is resolved, in order, against the layer's direct
    parameters, its direct buffers, and finally any tensor-valued
    attribute. An empty name is treated as ``"weight"``.

    Returns:
        The set of resolved names that were loaded.

    Raises:
        ValueError: if a buffer/attribute shape differs from the incoming
            tensor, or if the name matches nothing on this layer.
    """
    own_params = dict(self.named_parameters(recurse=False))
    own_buffers = dict(self.named_buffers(recurse=False))
    seen: set[str] = set()

    for raw_name, tensor in weights:
        # An empty suffix refers to the primary weight parameter.
        key = raw_name or "weight"

        if key in own_params:
            target = own_params[key]
            # Prefer a loader attached to the parameter itself.
            loader_fn = getattr(target, "weight_loader", self.weight_loader)
            loader_fn(target, tensor)
        elif key in own_buffers:
            target = own_buffers[key]
            if target.shape != tensor.shape:
                raise ValueError(
                    f"Shape mismatch when loading buffer '{key}': "
                    f"expected {target.shape}, got {tensor.shape}")
            target.copy_(tensor)
        else:
            target = getattr(self, key, None)
            if not isinstance(target, torch.Tensor):
                raise ValueError(
                    f"Unexpected weight '{key}' for "
                    f"{type(self).__name__}")
            if target.shape != tensor.shape:
                raise ValueError(
                    f"Shape mismatch when loading tensor '{key}': "
                    f"expected {target.shape}, got {tensor.shape}")
            target.copy_(tensor)

        seen.add(key)

    return seen
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
......
......@@ -30,7 +30,12 @@ from vllm.utils import is_pin_memory_available
import vllm.envs as envs
logger = init_logger(__name__)
from ..models.adapters_custom.adapters_classify import (
as_hunyuan_seq_cls_model,
)
# Maps an architecture name to a custom "classify" converter. The model
# loader consults this table before falling back to the generic
# sequence-classification adapter.
CLASSIFY_CLASSIFY_REGISTRY = {
    "HunYuanForCausalLM": as_hunyuan_seq_cls_model,
}
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
......@@ -257,8 +262,11 @@ def _get_model_architecture(
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
if arch in CLASSIFY_CLASSIFY_REGISTRY.keys():
model_cls = CLASSIFY_CLASSIFY_REGISTRY[arch](model_cls)
else:
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward":
logger.debug_once("Converting to reward model.")
model_cls = as_reward_model(model_cls)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from vllm.transformers_utils.config import (get_hf_file_bytes,
get_hf_file_to_dict)
from ..interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
_T = TypeVar("_T", bound=type[nn.Module])
logger = init_logger(__name__)
# Class-name suffixes of generative models. `_get_pooling_model_name` strips
# these (in this order) before appending a pooling-specific suffix.
_GENERATE_SUFFIXES = [
    "ForCausalLM",
    "ForConditionalGeneration",
    "ChatModel",
    "LMHeadModel",
]
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
    """Build the Sentence-Transformers Dense projection stack, if present.

    Reads ``modules.json`` for the model, keeps the entries whose type is
    ``sentence_transformers.models.Dense`` and, for each one whose config
    and weights can be fetched, appends an ``nn.Linear`` (plus its
    activation when ``activation_function`` is configured).

    Returns:
        An ``nn.Sequential`` cast to ``model_config.head_dtype``, or
        ``None`` when no Dense module is listed or any step raises (the
        exception is logged, not propagated).
    """
    try:
        manifest = get_hf_file_to_dict("modules.json", model_config.model,
                                       model_config.revision)
        if not manifest:
            return None

        # The manifest is either a bare list or a dict wrapping one.
        entries = (manifest.get("modules", [])
                   if isinstance(manifest, dict) else manifest)
        dense_entries = [
            entry for entry in entries
            if entry.get("type") == "sentence_transformers.models.Dense"
        ]
        if not dense_entries:
            return None

        stack = []
        for entry in dense_entries:
            subdir = entry.get("path", "")
            cfg_file = f"{subdir}/config.json" if subdir else "config.json"
            dense_cfg = get_hf_file_to_dict(cfg_file, model_config.model,
                                            model_config.revision)
            if not dense_cfg:
                continue

            projection = nn.Linear(
                dense_cfg.get("in_features", 768),
                dense_cfg.get("out_features", 768),
                bias=dense_cfg.get("bias", True),
                dtype=model_config.head_dtype,
            )
            # Only keep layers whose weights were actually found.
            if _load_dense_weights(projection, subdir, model_config):
                stack.append(projection)
                activation = dense_cfg.get("activation_function")
                if activation:
                    stack.append(get_act_fn(activation))

        return nn.Sequential(*stack).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")
        return None
def _load_dense_weights(linear: nn.Linear, folder: str,
                        model_config: "ModelConfig") -> bool:
    """Populate ``linear`` from a Dense module's checkpoint file.

    Tries ``model.safetensors`` and then ``pytorch_model.bin`` under
    ``folder``. In the first file that yields bytes and contains one of
    the keys ``weight`` / ``linear.weight`` / ``dense.weight``, that
    tensor (and the matching bias, when the layer has one) is loaded via
    the parameter's ``weight_loader`` or vLLM's default loader.

    Returns:
        ``True`` once a weight was loaded, ``False`` if no candidate file
        provided one. Per-file failures are logged and skipped.
    """
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    for candidate in ("model.safetensors", "pytorch_model.bin"):
        remote_path = f"{folder}/{candidate}" if folder else candidate
        try:
            raw = get_hf_file_bytes(remote_path, model_config.model,
                                    model_config.revision)
            if not raw:
                continue

            if candidate.endswith(".safetensors"):
                from safetensors.torch import load as load_safetensors
                tensors = load_safetensors(raw)
            else:
                import io
                tensors = torch.load(io.BytesIO(raw),
                                     map_location="cpu",
                                     weights_only=True)

            # First key naming convention present in the checkpoint wins.
            found_key = next(
                (key for key in ("weight", "linear.weight", "dense.weight")
                 if key in tensors),
                None,
            )
            if found_key is None:
                continue

            load_weight = getattr(linear.weight, "weight_loader",
                                  default_weight_loader)
            load_weight(linear.weight, tensors[found_key])

            bias_key = found_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in tensors:
                load_bias = getattr(linear.bias, "weight_loader",
                                    default_weight_loader)
                load_bias(linear.bias, tensors[bias_key])
            return True
        except Exception:
            logger.exception("Failed to load %s", candidate)
            continue

    return False
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    """Derive a pooling class name from a generative class name.

    Every suffix in ``_GENERATE_SUFFIXES`` is stripped from the end of
    ``orig_model_name`` in turn, then ``pooling_suffix`` is appended.
    """
    stem = orig_model_name
    for suffix in _GENERATE_SUFFIXES:
        # Equivalent to str.removesuffix: only trim on an exact tail match.
        if suffix and stem.endswith(suffix):
            stem = stem[:len(stem) - len(suffix)]
    return stem + pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
    """Wrap a multimodal model so it exposes its language model's pooler.

    The source of ``orig_cls`` is parsed and scanned for a call to the
    bare name ``init_vllm_registered_model``. If none is found, ``None``
    is returned; otherwise a subclass is returned whose ``pooler`` is
    taken from ``self.get_language_model().pooler`` after construction.
    """
    source_tree = ast.parse(inspect.getsource(orig_cls))
    # Only direct-name calls count (e.g. `f(...)`, not `obj.f(...)`).
    builds_registered_model = any(
        isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
        and node.func.id == "init_vllm_registered_model"
        for node in ast.walk(source_tree))
    if not builds_registered_model:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Reuse the pooler owned by the wrapped language model.
            language_model = self.get_language_model()
            self.pooler = language_model.pooler

    return ModelForPooling  # type: ignore
def _create_pooling_model_cls(orig_cls: _T) -> _T:
    """Create the shared pooling subclass used by the ``as_*_model`` adapters.

    The returned class removes ``lm_head`` / ``logits_processor`` after
    construction, calls ``_init_pooler`` unless the model already has a
    truthy ``pooler``, and filters ``lm_head.`` entries out of incoming
    weights. Subclasses must implement ``_init_pooler``.
    """
    # Lazy import
    from ..utils import AutoWeightsLoader, WeightsMapper

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
            self.vllm_config = vllm_config

            # Generation-only members are not used in pooling models.
            for unused in ("lm_head", "logits_processor"):
                if hasattr(self, unused):
                    delattr(self, unused)

            # Keep a pooler instance the wrapped model already defines.
            existing_pooler = getattr(self, "pooler", None)
            if not existing_pooler:
                self._init_pooler(vllm_config, prefix=prefix)

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            raise NotImplementedError

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            # TODO: Support uninitialized params tracking

            # `lm_head` was deleted in __init__, so its tensors are dropped.
            filtered = ((key, tensor) for key, tensor in weights
                        if not key.startswith("lm_head."))

            # If `*ForCausalLM` defines `load_weights` on the inner model
            # and no other child module owns parameters, checkpoints from
            # both `*Model` and `*ForCausalLM` are supported.
            if hasattr(self, "model") and hasattr(self.model, "load_weights"):

                def _has_no_params(module) -> bool:
                    return next(module.parameters(), None) is None

                only_inner_has_params = all(
                    child_name == "model" or _has_no_params(child)
                    for child_name, child in self.named_children())

                if only_inner_has_params:
                    prefix_mapper = WeightsMapper(
                        orig_to_new_prefix={"model.": ""})
                    inner_loaded = self.model.load_weights(
                        prefix_mapper.apply(filtered))
                    return {f"model.{name}" for name in inner_loaded}

            # Fallback for classes without their own loader.
            if not hasattr(orig_cls, "load_weights"):
                return AutoWeightsLoader(self).load_weights(filtered)

            # For most other models
            return orig_cls.load_weights(self, filtered)  # type: ignore

    return ModelForPooling  # type: ignore
def as_embedding_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support embeddings.

    The resulting model dispatches the "encode" and "embed" pooling tasks
    to poolers built from the configured ``pooler_config``.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Classes that already pool are returned untouched.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForEmbedding(pooling_base):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            task_poolers = {
                "encode": Pooler.for_encode(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
            self.pooler = DispatchPooler(task_poolers)

    ModelForEmbedding.__name__ = _get_pooling_model_name(
        cls.__name__, "ForEmbedding")

    return ModelForEmbedding  # type: ignore
def as_hunyuan_seq_cls_model(cls: _T) -> _T:
    """
    Subclass an existing HunYuan vLLM model to support the classify task.

    The pooled hidden state is passed through a two-layer head,
    ``pool_head`` -> tanh -> ``pool_head2``, and the resulting logits are
    returned as-is (the pooler is configured with ``PoolerIdentity``, so
    no softmax/sigmoid is applied here).

    Note:
        The HF config must enable ``add_classification_head`` and provide
        ``class_num``. ``score`` is an alias of ``pool_head2`` so that code
        expecting a ``score`` head keeps working.

    Raises (when the returned class is instantiated):
        ValueError: if ``add_classification_head`` is missing or falsy,
            because the classifier cannot be built without it.
    """
    # Avoid modifying existing classification models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.pooler import (ClassifierPooler,
                                                   DispatchPooler,
                                                   PoolerIdentity,
                                                   PoolingMethod, PoolingType)
    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
    from vllm.sequence import IntermediateTensors

    from ..utils import get_model_hidden_size, maybe_prefix

    class ModelForSequenceClassification(_create_pooling_model_cls(cls),
                                         SupportsCrossEncoding):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            config = vllm_config.model_config.hf_config
            quant_config = vllm_config.quant_config
            pooler_config = vllm_config.model_config.pooler_config

            # Fail fast: without the head, `_classifier` would only break
            # later with an opaque AttributeError on `pool_head`.
            if not getattr(config, "add_classification_head", False):
                raise ValueError(
                    "HunYuan sequence classification requires "
                    "`add_classification_head` to be enabled in the model "
                    "config")
            assert pooler_config is not None, (
                "PoolerConfig must be provided to use classification head")

            hidden_size = get_model_hidden_size(config)

            # NOTE(review): the head is pinned to float32 rather than
            # following `model_config.head_dtype` — confirm this is intended.
            self.pool_head = ReplicatedLinear(
                hidden_size,
                hidden_size,
                bias=True,
                quant_config=quant_config,
                params_dtype=torch.float32,
                prefix=maybe_prefix(prefix, "pool_head"),
                return_bias=False,
            )
            self.pool_head2 = ReplicatedLinear(
                hidden_size,
                config.class_num,
                bias=True,
                quant_config=quant_config,
                params_dtype=torch.float32,
                prefix=maybe_prefix(prefix, "pool_head2"),
                return_bias=True,
            )
            # Compatibility with *ForSequenceClassification: `score` points
            # at the final classification layer instead of being a
            # separately created layer; `pool_head2` is the scoring layer.
            self.score = self.pool_head2

            # Determine pooling type (fallback to config.pool_type)
            pooling_type_str = (pooler_config.pooling_type
                                if pooler_config.pooling_type is not None
                                else getattr(config, "pool_type",
                                             "LAST")).upper()
            # Accept "LASTTOKEN" as a spelling of "LAST".
            if pooling_type_str == "LASTTOKEN":
                pooling_type_str = "LAST"
            pooling_type = PoolingType[pooling_type_str]

            self.pooler = DispatchPooler({
                "classify":
                ClassifierPooler(
                    pooling=PoolingMethod.from_pooling_type(pooling_type),
                    classifier=self._classifier,
                    act_fn=PoolerIdentity(),
                )
            })

        def _classifier(self, x: torch.Tensor):
            """Apply ``pool_head`` -> tanh -> ``pool_head2`` to pooled data."""
            x = self.pool_head(x)
            # A layer built with return_bias=True yields (output, bias).
            if isinstance(x, tuple):
                x = x[0]
            x = torch.tanh(x)
            x = self.pool_head2(x)
            if isinstance(x, tuple):
                x = x[0]
            return x

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return super().forward(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
            tokens = getattr(self.config, "classifier_from_token", None)
            method = getattr(self.config, "method", None)
            if tokens is None and method is None:
                return super().load_weights(weights)

            # Online convert ForCausalLM into
            # ForSequenceClassification model.
            return seq_cls_model_loader(self, weights)

    ModelForSequenceClassification.__name__ = _get_pooling_model_name(
        cls.__name__, "ForSequenceClassification")

    return ModelForSequenceClassification  # type: ignore
def as_reward_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support reward modeling.

    The resulting model dispatches only the "encode" pooling task, using
    a pooler built from the configured ``pooler_config``.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Classes that already pool are returned untouched.
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    pooling_base = _create_pooling_model_cls(cls)

    class ModelForReward(pooling_base):

        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            task_poolers = {"encode": Pooler.for_encode(pooler_config)}
            self.pooler = DispatchPooler(task_poolers)

    ModelForReward.__name__ = _get_pooling_model_name(
        cls.__name__, "ForReward")

    return ModelForReward  # type: ignore
class SequenceClassificationConfig(VerifyAndUpdateConfig):
    """Config hook for models converted online into sequence classifiers."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Validate ``method`` / ``classifier_from_token`` and set derived
        fields (``num_labels``, ``use_pad_token``) on the HF config.

        Does nothing when the HF config has no ``method``.
        """
        hf_config = vllm_config.model_config.hf_config
        load_method = getattr(hf_config, "method", None)
        cls_tokens = getattr(hf_config, "classifier_from_token", None)

        if load_method is None:
            return

        assert cls_tokens is not None
        assert load_method in SEQ_CLS_LOAD_METHODS, \
            f"method {load_method} not supported"

        two_way = load_method == "from_2_way_softmax"
        if two_way:
            assert len(cls_tokens) == 2
        # The 2-way softmax variant collapses to one label; otherwise
        # there is one label per classifier token.
        hf_config.num_labels = 1 if two_way else len(cls_tokens)

        # `llm as reranker` defaults to not using pad_token
        hf_config.use_pad_token = getattr(hf_config, "use_pad_token", False)
def load_weights_using_from_2_way_softmax(
        model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load a causal-LM checkpoint and derive a one-row ``score`` weight.

    A temporary ``lm_head`` is attached (the embedding table when
    ``tie_word_embeddings`` is set), all weights are loaded, and the
    ``score`` weight is set to ``lm_head[true_token] - lm_head[false_token]``
    computed in float32. The temporary ``lm_head`` is removed afterwards.

    Returns:
        The loaded parameter names, with ``score.weight`` added and
        ``lm_head.weight`` removed.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
    token_pair = cast(list[int],
                      getattr(model.config, "classifier_from_token", []))
    assert len(token_pair) == 2

    if not model.config.tie_word_embeddings:
        quant_config = model.vllm_config.quant_config
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=quant_config)
    else:
        model.lm_head = model.model.embed_tokens

    loaded = AutoWeightsLoader(model).load_weights(weights)

    from vllm.transformers_utils.tokenizer import get_tokenizer
    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    # First token is the "false" class, second is the "true" class.
    false_id = tokenizer.convert_tokens_to_ids(token_pair[0])
    true_id = tokenizer.convert_tokens_to_ids(token_pair[1])

    true_row = model.lm_head.weight.data[[true_id]].to(torch.float32)
    false_row = model.lm_head.weight.data[[false_id]].to(torch.float32)
    score_weight = true_row - false_row

    score_param = model.score.weight
    load_score = getattr(score_param, "weight_loader", default_weight_loader)
    load_score(score_param, score_weight)

    # The temporary head was only needed to derive the score weight.
    del model.lm_head
    loaded.add("score.weight")
    loaded.discard("lm_head.weight")
    return loaded
def load_weights_no_post_processing(model,
                                    weights: Iterable[tuple[str,
                                                            torch.Tensor]]):
    """Load a causal-LM checkpoint and copy selected ``lm_head`` rows into
    the ``score`` weight.

    A temporary ``lm_head`` is attached (the embedding table when
    ``tie_word_embeddings`` is set), all weights are loaded, and the rows
    for every token in ``classifier_from_token`` are loaded into
    ``score.weight`` unchanged. The temporary ``lm_head`` is removed
    afterwards.

    Returns:
        The loaded parameter names, with ``score.weight`` added and
        ``lm_head.weight`` removed.
    """
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead)
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config
    cls_tokens = cast(list[int],
                      getattr(model.config, "classifier_from_token", []))
    assert len(cls_tokens) > 0

    if not model.config.tie_word_embeddings:
        quant_config = model.vllm_config.quant_config
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=quant_config)
    else:
        model.lm_head = model.model.embed_tokens

    loaded = AutoWeightsLoader(model).load_weights(weights)

    from vllm.transformers_utils.tokenizer import get_tokenizer
    tokenizer = get_tokenizer(model_config.tokenizer,
                              revision=model_config.tokenizer_revision,
                              tokenizer_mode=model_config.tokenizer_mode,
                              trust_remote_code=model_config.trust_remote_code)

    row_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in cls_tokens]
    score_weight = model.lm_head.weight.data[row_ids]

    score_param = model.score.weight
    load_score = getattr(score_param, "weight_loader", default_weight_loader)
    load_score(score_param, score_weight)

    # The temporary head was only needed to derive the score weight.
    del model.lm_head
    loaded.add("score.weight")
    loaded.discard("lm_head.weight")
    return loaded
# Dispatch table from the HF config's `method` value to the loader that
# converts a *ForCausalLM checkpoint into a sequence-classification model.
SEQ_CLS_LOAD_METHODS = {
    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
    "no_post_processing": load_weights_no_post_processing,
}
def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
    """Dispatch weight loading by the HF config's ``method`` field.

    Online convert ForCausalLM into ForSequenceClassification model.
    Known users of each method:
      - from_2_way_softmax:
          - Qwen3ForCausalLM (Qwen3-Reranker)
          - Qwen2ForCausalLM (mxbai-rerank-v2)
      - no_post_processing:
          - GemmaForCausalLM (bge-reranker-v2-gemma)
    """
    hf_config = model.vllm_config.model_config.hf_config
    method = getattr(hf_config, "method", None)
    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"

    load_fn = SEQ_CLS_LOAD_METHODS[method]
    return load_fn(model, weights)
......@@ -36,6 +36,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -553,23 +556,23 @@ class HunYuanDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
residual=hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, ori_kv_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_states=kv_states,
)
hidden_states =residual+hidden_states
residual=hidden_states
hidden_states= self.post_attention_layernorm(hidden_states)
hidden_states=self.mlp(hidden_states)
hidden_states=hidden_states+residual
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual, ori_kv_states
......@@ -614,11 +617,13 @@ class HunYuanModel(nn.Module):
prefix=prefix,
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
    embed = self.embed_tokens
    return embed(input_ids)
......@@ -650,7 +655,7 @@ class HunYuanModel(nn.Module):
residual,
prev_kv_states,
)
if (getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0):
prev_kv_states = kv_states
......@@ -662,8 +667,8 @@ class HunYuanModel(nn.Module):
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if not self.config.add_classification_head:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def _split_qkv_weight(self, qkv: torch.Tensor):
......@@ -732,6 +737,10 @@ class HunYuanModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if name.startswith("norm."):
# Some checkpoints omit the final norm; treat as handled.
loaded_params.add(name)
continue
if "rotary_emb.inv_freq" in name:
continue
if "gate_proj_bias" in name:
......@@ -880,8 +889,16 @@ class HunYuanModel(nn.Module):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if "norm.weight" in params_dict:
loaded_params.add("norm.weight")
return loaded_params
class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
......@@ -902,8 +919,11 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.pad_id = self.config.pad_id
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
......@@ -924,6 +944,7 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
......@@ -957,18 +978,26 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the inner ``self.model`` and return its output unchanged.

    ``token_type_ids`` is accepted but not forwarded to the inner model.
    """
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
......@@ -992,14 +1021,13 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Entries under ``lm_head.`` are skipped when the config sets
    ``tie_word_embeddings`` (presumably because the head then shares the
    embedding weights — confirm against the model definition).

    Returns:
        The set of parameter names reported as loaded.
    """
    # Build the loader once; an empty list means "skip nothing".
    skip_prefixes = []
    if self.config.tie_word_embeddings:
        skip_prefixes.append("lm_head.")
    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate embedding lookup to the inner ``self.model``."""
    inner_model = self.model
    return inner_model.get_input_embeddings(input_ids)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment