"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e32390931d32246915a1016e44e85e695b099040"
Unverified commit 32e94cff, authored by Deniz, committed by GitHub

tf add resize_token_embeddings method (#4351)



* resize token embeddings

* add tokens

* add tokens

* add tokens

* add t5 token method

* add t5 token method

* add t5 token method

* typo

* debugging input

* debugging input

* debug

* debug

* debug

* trying to set embedding tokens properly

* set embeddings for generation head too

* set embeddings for generation head too

* debugging

* debugging

* enable generation

* add base method

* add base method

* add base method

* return logits in the main call

* reverting to generation

* revert back

* set embeddings for the bert main layer

* description

* fix conflicts

* logging

* set base model as self

* refactor

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* v0

* v0

* finalize

* final

* black

* add tests

* revert back the emb call

* comments

* comments

* add the second test

* add vocab size config

* add tf models

* add tf models. add common tests

* remove model specific embedding tests

* stylish

* remove files

* stylez

* Update src/transformers/modeling_tf_transfo_xl.py

change the error.
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* adding unchanged weight test
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 97343326
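
For context, the new TF method mirrors the existing PyTorch workflow: add tokens to a tokenizer, then grow the model's input embedding matrix to match. A minimal usage sketch (the checkpoint and the added token strings are illustrative, not taken from this PR):

```python
from transformers import BertTokenizer, TFBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")

# Register new tokens with the tokenizer ...
tokenizer.add_tokens(["[NEW_TOK1]", "[NEW_TOK2]"])

# ... then resize the input token embeddings so the new ids have vectors.
# Rows for existing tokens are copied over; the added rows are freshly initialized.
model.resize_token_embeddings(len(tokenizer))
```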
@@ -60,6 +60,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.config = config
+        self.vocab_size = config.vocab_size
         self.position_embeddings = tf.keras.layers.Embedding(
             config.max_position_embeddings,
             config.embedding_size,
@@ -515,6 +516,10 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
...
@@ -497,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
+        self.initializer_range = config.initializer_range
         self.output_attentions = config.output_attentions
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
@@ -506,8 +507,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings

-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
...
@@ -213,6 +213,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.w

+    def set_input_embeddings(self, value):
+        self.w.weight = value
+        self.w.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
...
@@ -422,8 +422,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings

-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]

     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
...
@@ -217,6 +217,10 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
     def get_input_embeddings(self):
         return self.embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -371,9 +375,6 @@ class TFElectraModel(TFElectraPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.electra = TFElectraMainLayer(config, name="electra")

-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(self, inputs, **kwargs):
         r"""
@@ -422,9 +423,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
         self.electra = TFElectraMainLayer(config, name="electra")
         self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
@@ -519,9 +517,6 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
         self.activation = config.hidden_act
         self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     def get_output_embeddings(self):
         return self.generator_lm_head
...
@@ -235,8 +235,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.wte

-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.wte.weight = value
+        self.wte.vocab_size = self.wte.weight.shape[0]

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
...
@@ -227,8 +227,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.tokens_embed

-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.tokens_embed.weight = value
+        self.tokens_embed.vocab_size = value.shape[0]

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
...
@@ -101,9 +101,6 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super().__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name="embeddings")

-    def get_input_embeddings(self):
-        return self.embeddings
-
 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
...
@@ -884,6 +884,16 @@ class TFT5Model(TFT5PreTrainedModel):
     def get_output_embeddings(self):
         return self.shared

+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        self.shared.vocab_size = self.shared.weight.shape[0]
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
+            pass
+        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)
+
     def get_encoder(self):
         return self.encoder
@@ -1011,6 +1021,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
     def get_output_embeddings(self):
         return self.shared

+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
+            pass
+        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)
+
     def get_encoder(self):
         return self.encoder
...
@@ -468,6 +468,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.word_emb

+    def set_input_embeddings(self, value):
+        raise NotImplementedError
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
...
@@ -199,6 +199,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             raise NotImplementedError

+    def set_input_embeddings(self, value):
+        """
+        Set model's input embeddings
+
+        Args:
+            value (:obj:`tf.keras.layers.Layer`):
+                A module mapping vocabulary to hidden states.
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+
     def get_output_embeddings(self):
         """
         Returns the model's output embeddings.
@@ -209,10 +223,50 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         """
         return None  # Overwrite for models with output embeddings

+    def resize_token_embeddings(self, new_num_tokens=None):
+        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
+        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens: (`optional`) int:
+                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
+                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.
+
+        Return: ``tf.Variable``
+            Pointer to the input tokens Embeddings Module of the model
+        """
+        model_embeds = self._resize_token_embeddings(new_num_tokens)
+
+        if new_num_tokens is None:
+            return model_embeds
+
+        return model_embeds
+
+    def _resize_token_embeddings(self, new_num_tokens):
+        # get_input_embeddings and set_input_embeddings need to be implemented in the base layer.
+        base_model = getattr(self, self.base_model_prefix, self)
+        old_embeddings = base_model.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        base_model.set_input_embeddings(new_embeddings)
+
+        # Update base model and current model config
+        self.config.vocab_size = new_num_tokens
+        base_model.vocab_size = new_num_tokens
+
+        return base_model.get_input_embeddings()
+
+    def _get_word_embeddings(self, embeddings):
+        if hasattr(embeddings, "word_embeddings"):
+            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
+            return embeddings.word_embeddings
+        elif hasattr(embeddings, "weight"):
+            # TFSharedEmbeddings
+            return embeddings.weight
+        else:
+            raise ValueError("word embedding is not defined.")
+
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
-            Reducing the size will remove vectors from the end
+            Reducing the size will remove vectors from the end.

         Args:
             new_num_tokens: (`optional`) int
@@ -221,42 +275,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 Reducing the size will remove vectors from the end
                 If not provided or None: return the provided token Embedding Module.

         Return: ``tf.Variable``
-            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
+            Pointer to the resized word Embedding Module or the old Embedding Module if new_num_tokens is None
         """
-        # if new_num_tokens is None:
-        #     return old_embeddings
-
-        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
-        # if old_num_tokens == new_num_tokens:
-        #     return old_embeddings
-
-        # # Build new embeddings
-        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        # new_embeddings.to(old_embeddings.weight.device)
-
-        # # initialize all new embeddings (in particular added tokens)
-        # self._init_weights(new_embeddings)
-
-        # # Copy token embeddings from the previous weights
-        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
-
-        # return new_embeddings
-
-    def resize_token_embeddings(self, new_num_tokens=None):
-        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
-        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
-
-        Arguments:
-            new_num_tokens: (`optional`) int:
-                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
-                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.
-
-        Return: ``tf.Variable``
-            Pointer to the input tokens Embeddings Module of the model
-        """
-        raise NotImplementedError
+        word_embeddings = self._get_word_embeddings(old_embeddings)
+        if new_num_tokens is None:
+            return word_embeddings
+
+        old_num_tokens, old_embedding_dim = word_embeddings.shape
+        if old_num_tokens == new_num_tokens:
+            return word_embeddings
+
+        # initialize new embeddings
+        # todo: initializer range is not always passed in config.
+        init_range = getattr(self.config, "initializer_range", 0.02)
+        new_embeddings = self.add_weight(
+            "weight",
+            shape=[new_num_tokens, old_embedding_dim],
+            initializer=get_initializer(init_range),
+            dtype=tf.float32,
+        )
+        init_weights = new_embeddings.numpy()
+
+        # Copy token embeddings from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
+        new_embeddings.assign(init_weights)
+
+        return new_embeddings

     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
...
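
The core of the new `_get_resized_embeddings` is an allocate-then-copy step: build a freshly initialized `[new_num_tokens, embedding_dim]` variable, then overwrite its leading rows with the old weights. A standalone sketch in plain TensorFlow of that idea (the helper name and shapes are illustrative; the library's `get_initializer` is a truncated-normal initializer, mirrored here):

```python
import numpy as np
import tensorflow as tf


def resize_embedding_matrix(old_weights, new_num_tokens, init_range=0.02):
    """Return a new [new_num_tokens, dim] variable with the old rows copied in."""
    old_num_tokens, embedding_dim = old_weights.shape
    initializer = tf.keras.initializers.TruncatedNormal(stddev=init_range)
    new_weights = tf.Variable(initializer(shape=(new_num_tokens, embedding_dim), dtype=tf.float32))

    # Copy the rows both sizes share; any extra rows keep their fresh initialization.
    values = new_weights.numpy()
    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    values[:num_tokens_to_copy] = old_weights.numpy()[:num_tokens_to_copy]
    new_weights.assign(values)
    return new_weights


# Example: grow a 10-token table to 12 tokens; the first 10 rows are unchanged.
old = tf.Variable(tf.random.normal([10, 8]))
new = resize_embedding_matrix(old, 12)
assert np.allclose(old.numpy(), new.numpy()[:10])
```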
@@ -306,6 +306,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings

+    def set_input_embeddings(self, value):
+        self.embeddings.weight = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
...
@@ -388,6 +388,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.word_embedding

+    def set_input_embeddings(self, value):
+        self.word_embedding.weight = value
+        self.word_embedding.vocab_size = value.shape[0]
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(
...
@@ -472,6 +472,30 @@ class TFModelTesterMixin:
             model(inputs)

+    def test_resize_token_embeddings(self):
+        if not self.test_resize_embeddings:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        INPUT_SHAPE = [1, 10, config.hidden_size]
+        for model_class in self.all_model_classes:
+            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
+                # build the embeddings
+                model = model_class(config=config)
+                emb_old = model.get_input_embeddings()
+                emb_old.build(INPUT_SHAPE)
+                # reshape the embeddings
+                new_embeddings = model._get_resized_embeddings(emb_old, size)
+                # check that the resized embeddings size matches the desired size
+                assert_size = size if size is not None else config.vocab_size
+                self.assertEqual(new_embeddings.shape[0], assert_size)
+                # check that weights remain the same after resizing
+                emd_old_weights = model._get_word_embeddings(emb_old)
+                models_equal = True
+                for p1, p2 in zip(emd_old_weights.numpy(), new_embeddings.numpy()):
+                    if np.sum(abs(p1 - p2)) > 0:
+                        models_equal = False
+                self.assertTrue(models_equal)
+
     def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
...
@@ -169,7 +169,6 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
     )
     test_pruning = True
     test_torchscript = True
-    test_resize_embeddings = True
     test_head_masking = True

     def setUp(self):
...