Unverified commit fd6204b2, authored by Patrick von Platen, committed by GitHub

[Lazy init] Force fall back to slow init for composite models (#11705)



* fix encoder-decoder & RAG

* finalize

* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/rag/modeling_rag.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 5c1cda9d
src/transformers/modeling_utils.py

@@ -510,6 +510,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
         return None  # Overwrite for models with output embeddings
 
+    def _init_weights(self, module):
+        """
+        Initialize the weights. This method should be overridden by derived class.
+        """
+        raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}")
+
     def tie_weights(self):
         """
         Tie the weights between the input embeddings and the output embeddings.

@@ -1205,7 +1211,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
 
             model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
-                model, state_dict, pretrained_model_name_or_path
+                model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
             )
 
         # make sure token embedding weights are still tied if needed

@@ -1225,7 +1231,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model
 
     @classmethod
-    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):
+    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
 
         # Convert old format to new format if needed from a PyTorch state_dict
         old_keys = []

@@ -1273,12 +1279,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
-        # tie unintialized modules
-        unintialized_modules = model.retrieve_modules_from_names(
-            missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
-        )
-        for module in unintialized_modules:
-            model._init_weights(module)
+        if _fast_init:
+            # retrieve unintialized modules and initialize
+            unintialized_modules = model.retrieve_modules_from_names(
+                missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
+            )
+            for module in unintialized_modules:
+                model._init_weights(module)
 
         # copy state_dict so _load_from_state_dict can modify it
         metadata = getattr(state_dict, "_metadata", None)
         state_dict = state_dict.copy()
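The new `_fast_init` flag gates the targeted re-initialization step: when it is True (the default), only the modules whose parameters are missing from the checkpoint are re-initialized; when a composite model forces it to False, that pass is skipped and the weights set up in the model constructor are kept for any missing parameters. A rough, self-contained sketch of that control flow (the module lookup below is a stand-in, not the library's real `retrieve_modules_from_names`):

import torch.nn as nn

def load_with_optional_fast_init(model: nn.Module, state_dict: dict, _fast_init: bool = True):
    # Toy illustration of the gate introduced above, not the library's loading code.
    missing_keys = sorted(set(model.state_dict()) - set(state_dict))

    if _fast_init:
        # Fast path: re-initialize only the submodules whose parameters are
        # missing from the checkpoint.
        missing_modules = {key.rsplit(".", 1)[0] for key in missing_keys}
        for name, module in model.named_modules():
            if name in missing_modules:
                model._init_weights(module)  # assumes the model implements _init_weights
    # Slow path (_fast_init=False): keep the full initialization performed in the
    # model constructor, which is what composite models now fall back to.

    model.load_state_dict(state_dict, strict=False)
    return model, missing_keys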
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

@@ -221,6 +221,13 @@ class EncoderDecoderModel(PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         return self.decoder.set_output_embeddings(new_embeddings)
 
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super().from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_encoder_decoder_pretrained(
         cls,
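With this override in place, loading a composite encoder-decoder checkpoint transparently takes the slow-init path; for example (the checkpoint name is only illustrative):

from transformers import EncoderDecoderModel

# from_pretrained forces _fast_init=False internally, so missing submodule weights
# are covered by the regular constructor-time initialization instead of the fast path.
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")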
src/transformers/models/rag/modeling_rag.py

@@ -232,6 +232,13 @@ class RagPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rag"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super().from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_pretrained_question_encoder_generator(
         cls,
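The same two-line override is the pattern for any custom composite model that cannot support the targeted fast init; a sketch with an illustrative class name:

from transformers import PreTrainedModel

class MyCompositeModel(PreTrainedModel):  # hypothetical composite model
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Composite models assemble independently created submodels, so mapping
        # missing checkpoint keys back onto modules is unreliable; force slow init.
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)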