Unverified commit 7ac48fd3, authored by grYe99, committed by GitHub
Browse files

[Model] Add AutoWeightsLoader support for jais (#38074)


Signed-off-by: grYe99 <guorongye99@gmail.com>
Co-authored-by: grYe99 <guorongye99@gmail.com>
parent d6bb2a9d
...@@ -53,6 +53,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig ...@@ -53,6 +53,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_empty_intermediate_tensors_factory,
make_layers, make_layers,
...@@ -311,6 +312,35 @@ class JAISModel(nn.Module): ...@@ -311,6 +312,35 @@ class JAISModel(nn.Module):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Args:
        weights: Iterable of ``(parameter name, tensor)`` pairs.

    Returns:
        The names of every parameter that was actually loaded.

    Raises:
        KeyError: If a non-skipped name has no matching entry in
            ``self.named_parameters(remove_duplicate=False)``.
    """
    # Entries that are never loaded: the attention mask buffers and the
    # relative position tensors.
    # NOTE: "c_attn.bias" does not contain ".attn.bias", so it is kept.
    skipped_markers = (".attn.bias", ".attn.masked_bias", "relative_pe")
    # The HF GPT-2 implementation uses Conv1D instead of Linear, so the
    # weights of these layers have to be transposed before loading.
    # Note(zhuohan): the transposition below might break quantized models.
    conv1d_markers = ("c_attn", "c_proj", "c_fc")

    named_params = dict(self.named_parameters(remove_duplicate=False))
    loaded: set[str] = set()

    for name, tensor in weights:
        if any(marker in name for marker in skipped_markers):
            continue
        if is_pp_missing_parameter(name, self):
            continue

        target = named_params[name]

        # Only ".weight" tensors are transposed; one transposition is
        # applied per Conv1D marker found in the name.
        if name.endswith(".weight"):
            for marker in conv1d_markers:
                if marker in name:
                    tensor = tensor.t()

        # Prefer the parameter's own loader when it defines one.
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)
        loaded.add(name)

    return loaded
class JAISLMHeadModel(nn.Module, SupportsPP): class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -364,36 +394,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -364,36 +394,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) loader = AutoWeightsLoader(
loaded_params: set[str] = set() self,
for name, loaded_weight in weights: skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
if "lm_head.weight" in name: )
# GPT-2 ties the weights of the embedding layer and the final return loader.load_weights(weights)
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if "relative_pe" in name:
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment