Unverified Commit 51ba8395 authored by Calvin Chen's avatar Calvin Chen Committed by GitHub
Browse files

[Model] use AutoWeightsLoader for bart (#18299)


Signed-off-by: default avatarcalvin chen <120380290@qq.com>
parent d1fb65bd
...@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant, SupportsV0Only from .interfaces import SupportsQuant, SupportsV0Only
from .utils import maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -700,7 +700,8 @@ class BartDecoder(nn.Module): ...@@ -700,7 +700,8 @@ class BartDecoder(nn.Module):
class BartModel(nn.Module, SupportsQuant): class BartModel(nn.Module, SupportsQuant):
_tied_weights_keys = [ _tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight" "encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
] ]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant): ...@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant):
return decoder_outputs return decoder_outputs
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
other_weights = []
loaded_stacked_params = []
model_params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in model_params_dict:
continue
param = model_params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_stacked_params.append(name)
break
else:
if name in model_params_dict:
other_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params
class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper(
base_model_prefix = "model" orig_to_new_prefix={
"decoder.": "model.decoder.",
"encoder.": "model.encoder.",
"shared.": "model.shared."
},
orig_to_new_substr={
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
self.lm_head = BartParallelLMHead(config.vocab_size, self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model, config.d_model,
embed_scale=embed_scale) embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
...@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
sampling_metadata) sampling_metadata)
return logits return logits
stacked_params_mapping = { def load_weights(self, weights: Iterable[tuple[str,
"q_proj": { torch.Tensor]]) -> set[str]:
"param_name": "qkv_proj",
"shard_id": "q",
},
"k_proj": {
"param_name": "qkv_proj",
"shard_id": "k",
},
"v_proj": {
"param_name": "qkv_proj",
"shard_id": "v",
},
}
params_mapping = {
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
}
def _rename_key(self, key: str):
prefix = f"{self.base_model_prefix}."
key = key[len(prefix):] if key.startswith(prefix) else key
for src, dst in self.params_mapping.items():
key = key.replace(src, dst)
return key
def _rename_stacked_param(
self,
name: str,
) -> tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items():
if key in name:
name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"]
return name, None
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters())
weights_tuple_list = list(weights) weights_tuple_list = list(weights)
shared_embedding_weight = None shared_embedding_weight = None
shared_embedding_shard_id = None
for name, loaded_weight in weights_tuple_list: for name, loaded_weight in weights_tuple_list:
name = self._rename_key(name)
name, shard_id = self._rename_stacked_param(name)
if ('shared.weight' in name if ('shared.weight' in name
or 'encoder.embed_tokens.weight' in name or 'encoder.embed_tokens.weight' in name
or 'decoder.embed_tokens.weight' in name or 'decoder.embed_tokens.weight' in name
...@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
assert shared_embedding_weight is None, ( assert shared_embedding_weight is None, (
"Conflicting embedding weights.") "Conflicting embedding weights.")
shared_embedding_weight = loaded_weight shared_embedding_weight = loaded_weight
shared_embedding_shard_id = shard_id
else:
# Skip the specific downstream task weight.
if name.startswith('cls.'):
continue
# use Pooler instead.
if name.startswith('pooler.'):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in model_params_dict:
continue
param = model_params_dict[name] loader = AutoWeightsLoader(
weight_loader = getattr(param, "weight_loader", self,
default_weight_loader) skip_prefixes=(["cls.", "pooler."]),
if shard_id: )
weight_loader(param, loaded_weight, shard_id) loaded_params = loader.load_weights(weights_tuple_list,
else: mapper=self.hf_to_vllm_mapper)
weight_loader(param, loaded_weight)
if shared_embedding_weight is not None:
# Assign shared weight values weight_loader = getattr(self.lm_head.weight, "weight_loader",
encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] default_weight_loader)
encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", weight_loader(self.lm_head.weight, shared_embedding_weight)
default_weight_loader)
self.model.encoder.embed_tokens.weight = self.lm_head.weight
decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] self.model.decoder.embed_tokens.weight = self.lm_head.weight
decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", loaded_params.update({
default_weight_loader) 'model.encoder.embed_tokens.weight', 'lm_head.weight',
'model.decoder.embed_tokens.weight'
lm_head_in_param = top_params_dict['lm_head.weight'] })
lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
default_weight_loader) return loaded_params
assert shared_embedding_weight is not None
if shared_embedding_shard_id:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
shared_embedding_shard_id)
else:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment