Unverified commit 443f67e8, authored by Patrick von Platen, committed by GitHub

[PyTorch] Refactor Resize Token Embeddings (#8880)

* fix resize tokens

* correct mobile_bert

* move embedding fix into modeling_utils.py

* refactor

* fix lm head resize

* refactor

* break lines to make sylvain happy

* add new tests

* fix typo

* improve test

* skip bart-like for now

* check if base_model = get(...) is necessary

* clean files

* improve test

* fix tests

* revert style templates

* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
parent e52f9c0a
@@ -688,6 +688,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj
 
+    def set_output_embeddings(self, new_embeddings):
+        self.pred_layer.proj = new_embeddings
+
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         mask_token_id = self.config.mask_token_id
         lang_id = self.config.lang_id
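The setter above is the hook that lets the resize logic moved into modeling_utils.py work generically. Below is a rough sketch of the idea, not the upstream implementation: `resize_lm_head_sketch` is a made-up name, and the real code also handles dtype, weight tying, and TorchScript edge cases.

import torch
from torch import nn


def resize_lm_head_sketch(model, new_num_tokens):
    # Minimal sketch (not the upstream code) of how a generic resize can use
    # the new setter instead of model-specific attributes like `pred_layer.proj`.
    old_lm_head = model.get_output_embeddings()
    if old_lm_head is None:  # model has no separate output embeddings
        return

    old_num_tokens, hidden_size = old_lm_head.weight.shape
    has_bias = old_lm_head.bias is not None

    # Allocate a new head and copy over the overlapping vocabulary rows.
    new_lm_head = nn.Linear(hidden_size, new_num_tokens, bias=has_bias)
    new_lm_head = new_lm_head.to(old_lm_head.weight.device)
    n = min(old_num_tokens, new_num_tokens)
    with torch.no_grad():
        new_lm_head.weight[:n, :] = old_lm_head.weight[:n, :]
        if has_bias:
            new_lm_head.bias[:n] = old_lm_head.bias[:n]

    # The symmetric setter lets the base class install the resized head
    # without knowing each model's attribute layout.
    model.set_output_embeddings(new_lm_head)
    model.config.vocab_size = new_num_tokens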
@@ -1312,6 +1312,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss
 
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_loss = new_embeddings
+
     def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
@@ -781,6 +781,9 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass
 
+    @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
+    def test_resize_embeddings_untied(self):
+        pass
+
     @require_sentencepiece
     @require_tokenizers
     def test_tiny_model(self):
@@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
     def test_feed_forward_chunking(self):
         pass
 
+    @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
+    def test_resize_embeddings_untied(self):
+        pass
+
     @unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
     @require_torch
@@ -815,6 +815,10 @@ class ModelTesterMixin:
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
             # Input ids should be clamped to the maximum size of the vocabulary
             inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+
+            # make sure that decoder_input_ids are resized as well
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
             model(**self._prepare_for_class(inputs_dict, model_class))
 
             # Check that adding and removing tokens has not modified the first part of the embedding matrix.
@@ -825,6 +829,57 @@ class ModelTesterMixin:
         self.assertTrue(models_equal)
 
+    def test_resize_embeddings_untied(self):
+        (
+            original_config,
+            inputs_dict,
+        ) = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.test_resize_embeddings:
+            return
+
+        original_config.tie_word_embeddings = False
+
+        # if the model cannot untie its word embeddings -> leave test
+        if original_config.tie_word_embeddings:
+            return
+
+        for model_class in self.all_model_classes:
+            config = copy.deepcopy(original_config)
+            model = model_class(config).to(torch_device)
+
+            # if no output embeddings -> leave test
+            if model.get_output_embeddings() is None:
+                continue
+
+            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
+            model_vocab_size = config.vocab_size
+            model.resize_token_embeddings(model_vocab_size + 10)
+            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
+            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
+            model.resize_token_embeddings(model_vocab_size - 15)
+            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
+            # Check that it actually resizes the embeddings matrix
+            output_embeds = model.get_output_embeddings()
+            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
+            # Check bias if present
+            if output_embeds.bias is not None:
+                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
+
+            # Input ids should be clamped to the maximum size of the vocabulary
+            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+            if "decoder_input_ids" in inputs_dict:
+                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+            # Check that the model can still do a forward pass successfully (every parameter should be resized)
+            model(**self._prepare_for_class(inputs_dict, model_class))
+
     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
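For illustration, here is roughly what the new common test protects at the user level. This snippet is not part of the diff; it assumes BertForMaskedLM with tie_word_embeddings=False behaves as an untied model, which is the configuration the test exercises.

from transformers import BertConfig, BertForMaskedLM

config = BertConfig(tie_word_embeddings=False)
old_vocab_size = config.vocab_size
model = BertForMaskedLM(config)

# Growing the vocabulary must resize the input matrix *and* the untied lm head.
model.resize_token_embeddings(old_vocab_size + 10)
assert model.get_input_embeddings().weight.shape[0] == old_vocab_size + 10
assert model.get_output_embeddings().weight.shape[0] == old_vocab_size + 10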
@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_tie_model_weights(self):
         pass
 
-    # def test_auto_model(self):
-    #     # XXX: add a tiny model to s3?
-    #     model_name = "facebook/wmt19-ru-en-tiny"
-    #     tiny = AutoModel.from_pretrained(model_name)  # same vocab size
-    #     tok = AutoTokenizer.from_pretrained(model_name)  # same tokenizer
-    #     inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
-    #     with torch.no_grad():
-    #         tiny(**inputs_dict)
+    @unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
+    def test_resize_embeddings_untied(self):
+        pass
 
 
 @require_torch
@@ -574,6 +574,10 @@ class ReformerTesterMixin:
             # reformer cannot keep gradients in attentions or hidden states
             return
 
+    def test_resize_embeddings_untied(self):
+        # reformer cannot resize embeddings that easily
+        return
+
 
 @require_torch
 class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
@@ -444,6 +444,20 @@ class T5ModelTester:
             )
         )
 
+    def check_resize_embeddings_t5_v1_1(
+        self,
+        config,
+    ):
+        prev_vocab_size = config.vocab_size
+
+        config.tie_word_embeddings = False
+        model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
+        model.resize_token_embeddings(prev_vocab_size - 10)
+        self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
+        self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
+        self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     test_pruning = False
     test_torchscript = True
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_model_parallel = True
     is_encoder_decoder = True
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
 
+    def test_v1_1_resize_embeddings(self):
+        config = self.model_tester.prepare_config_and_inputs()[0]
+        self.model_tester.check_resize_embeddings_t5_v1_1(config)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
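The dedicated T5 check matters because T5 v1.1 checkpoints untie the lm head from the shared input embeddings. An illustrative equivalent outside the test harness (not part of the diff; assumes a default T5Config):

from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(tie_word_embeddings=False)
old_vocab_size = config.vocab_size
model = T5ForConditionalGeneration(config).eval()

# Shrinking the vocabulary must shrink both the shared input embeddings
# and the separate lm head.
model.resize_token_embeddings(old_vocab_size - 10)
assert model.get_input_embeddings().weight.shape[0] == old_vocab_size - 10
assert model.get_output_embeddings().weight.shape[0] == old_vocab_size - 10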
@@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
             self.assertEqual(model_vocab_size, model.config.vocab_size)
             self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
 
+    def test_resize_embeddings_untied(self):
+        # transfo-xl requires special resize for lm-head
+        return
+
 
 @require_torch
 class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
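Transfo-XL sits out the untied test because its adaptive input and softmax split the vocabulary over several embedding tables (the `emb_layers` seen in the context lines above) with different dimensions, so there is no single weight matrix for the generic helper to resize. A rough illustration of the shapes involved (cluster sizes are made up, not Transfo-XL's real cutoffs):

from torch import nn

# Adaptive input: the vocab is partitioned by token frequency into clusters,
# each with its own embedding table and its own embedding dimension.
emb_layers = nn.ModuleList([
    nn.Embedding(20000, 1024),  # head cluster (frequent tokens)
    nn.Embedding(40000, 256),   # tail cluster with a smaller dimension
    nn.Embedding(140000, 64),   # long-tail cluster, smaller still
])
# Adding tokens means choosing a cluster and resizing only that table,
# which is why the generic resize path does not apply here.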