Unverified Commit e1f5a71e authored by Calvin Chen's avatar Calvin Chen Committed by GitHub
Browse files

[Model] use AutoWeightsLoader for bloom (#18300)


Signed-off-by: default avatarcalvin chen <120380290@qq.com>
parent f4a8a374
......@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -229,6 +229,7 @@ class BloomModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_dim = config.hidden_size
......@@ -278,6 +279,38 @@ class BloomModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: BLOOM's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
......@@ -325,35 +358,15 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: BLOOM's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
weights = _add_transformer_prefix(weights)
return loader.load_weights(weights)
def _add_transformer_prefix(
weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
for name, tensor in weights:
if not name.startswith('transformer.'):
name = 'transformer.' + name
yield name, tensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment