Unverified Commit 14385c80 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix weight mapping test for Transfomers v5 (#33162)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 76139d08
...@@ -30,7 +30,12 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel: ...@@ -30,7 +30,12 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
model_cls: PreTrainedModel = getattr(transformers, model_arch) model_cls: PreTrainedModel = getattr(transformers, model_arch)
config = AutoConfig.from_pretrained(repo) config = AutoConfig.from_pretrained(repo)
with torch.device("meta"): with torch.device("meta"):
return model_cls._from_config(config) model = model_cls._from_config(config)
# TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
# https://github.com/huggingface/transformers/issues/43522
if getattr(config.get_text_config(), "tie_word_embeddings", False):
model.tie_weights()
return model
def model_architectures_for_test() -> list[str]: def model_architectures_for_test() -> list[str]:
......
...@@ -249,7 +249,8 @@ class Base( ...@@ -249,7 +249,8 @@ class Base(
# Layers before module list # Layers before module list
for name in pp_plan[:module_list_idx]: for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or ( if self.pp_group.is_first_rank or (
self.text_config.tie_word_embeddings and self.pp_group.is_last_rank getattr(self.text_config, "tie_word_embeddings", False)
and self.pp_group.is_last_rank
): ):
continue continue
setattr(self.model, name, PPMissingLayer()) setattr(self.model, name, PPMissingLayer())
......
...@@ -38,7 +38,8 @@ class CausalMixin(VllmModelForTextGeneration): ...@@ -38,7 +38,8 @@ class CausalMixin(VllmModelForTextGeneration):
# Tell `Base.load_weights` to skip # Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings # `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings: tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.") self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
...@@ -48,7 +49,7 @@ class CausalMixin(VllmModelForTextGeneration): ...@@ -48,7 +49,7 @@ class CausalMixin(VllmModelForTextGeneration):
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
if self.text_config.tie_word_embeddings: if tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights( self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings() self.model.get_input_embeddings()
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment