Unverified Commit bf7e3c51 authored by Jonghyun Choe's avatar Jonghyun Choe Committed by GitHub
Browse files

[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (#15939)


Signed-off-by: default avatarJonghyun Choe <andy.choe729@gmail.com>
parent a35a8a83
...@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -321,6 +321,45 @@ class BaiChuanModel(nn.Module): ...@@ -321,6 +321,45 @@ class BaiChuanModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsQuant): SupportsQuant):
...@@ -353,6 +392,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -353,6 +392,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
self.lm_head.weight.weight_loader = self.lm_head_weight_loader
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
...@@ -393,17 +433,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -393,17 +433,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ loader = AutoWeightsLoader(self)
# (param_name, shard_name, shard_id) return loader.load_weights(weights)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), def lm_head_weight_loader(self, param: nn.Parameter,
] loaded_weight: torch.Tensor):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights. # Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to: # Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
...@@ -412,34 +446,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -412,34 +446,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2 = self.config.vocab_size == 125696 is_baichuan2 = self.config.vocab_size == 125696
if is_baichuan2: if is_baichuan2:
loaded_weight = torch.nn.functional.normalize( loaded_weight = torch.nn.functional.normalize(loaded_weight)
loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping: default_weight_loader(param, loaded_weight)
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
...@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -241,6 +241,45 @@ class GPTNeoXModel(nn.Module): ...@@ -241,6 +241,45 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class GPTNeoXForCausalLM(nn.Module, SupportsPP): class GPTNeoXForCausalLM(nn.Module, SupportsPP):
...@@ -297,39 +336,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -297,39 +336,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) loader = AutoWeightsLoader(self)
loaded_params: Set[str] = set() return loader.load_weights(weights)
for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -27,7 +27,7 @@ from vllm.sequence import IntermediateTensors ...@@ -27,7 +27,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -266,6 +266,23 @@ class MPTModel(nn.Module): ...@@ -266,6 +266,23 @@ class MPTModel(nn.Module):
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MPTForCausalLM(nn.Module, SupportsPP): class MPTForCausalLM(nn.Module, SupportsPP):
...@@ -318,17 +335,5 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -318,17 +335,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) loader = AutoWeightsLoader(self)
loaded_params: Set[str] = set() return loader.load_weights(weights)
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment