"...tools/nnictl/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d165905d0ba24cfba414b8e0c20fa8d7c8ab6a6e"
Unverified commit ae0c27ad, authored by tom-p-reichel and committed by GitHub
Browse files

don't initialize the output embeddings if we're going to tie them to input embeddings (#28192)

* test that tied output embeddings aren't initialized on load

* don't initialize the output embeddings if we're going to tie them to the input embeddings
parent a937425e
......@@ -3746,6 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
_loaded_keys = loaded_keys
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
# if we're about to tie the output embeds to the input embeds we don't need to init them
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
output_embeddings._is_hf_initialized = True
else:
not_initialized_submodules = dict(model.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
......
......@@ -483,6 +483,40 @@ class ModelTesterMixin:
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
def test_fast_init_tied_embeddings(self):
    """Check that reloading a model does not initialize its tied output embeddings.

    A minimal model is defined whose `_init_weights` raises as soon as it is handed
    the output embedding module. The model is saved to a temporary directory and
    loaded back; the load raises if the output embeddings get initialized.
    """

    class MyClass(PreTrainedModel):
        """Tiny model exposing one input and one output embedding module."""

        config_class = PretrainedConfig
        _tied_weights_keys = ["output_embeddings.weight"]

        def __init__(self, config=None):
            # Fall back to a default configuration when none is supplied.
            if config is None:
                config = PretrainedConfig()
            super().__init__(config)
            self.input_embeddings = nn.Embedding(10, 10)
            self.output_embeddings = nn.Linear(10, 10, bias=False)
            self.tie_weights()

        def get_output_embeddings(self):
            return self.output_embeddings

        def set_output_embeddings(self, output_embeddings):
            self.output_embeddings = output_embeddings

        def get_input_embeddings(self):
            return self.input_embeddings

        def set_input_embeddings(self, input_embeddings):
            self.input_embeddings = input_embeddings

        def _init_weights(self, module):
            # Tripwire: every module except the output embeddings is accepted
            # silently (and left untouched).
            if module is not self.output_embeddings:
                return
            raise ValueError("unnecessarily initialized tied output embedding!")

    reference_model = MyClass()

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        reference_model.save_pretrained(checkpoint_dir)
        # Raises (via `_init_weights`) if loading initializes the tied output embeddings.
        MyClass.from_pretrained(checkpoint_dir)
def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment