Commit 600a4232 authored by thomwolf

add weights tying, attention and hidden states output tests

parent 04d2006f
@@ -141,7 +141,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     """
     def __init__(self, config, **kwargs):
         super(TFBertEmbeddings, self).__init__(**kwargs)
-        self.word_embeddings = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size, name='word_embeddings')
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
         self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
         self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
@@ -150,8 +152,44 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
+    def build(self, input_shape):
+        """Build shared word embedding layer """
+        with tf.name_scope("word_embeddings"):
+            # Create and initialize weights. The random normal initializer was chosen
+            # arbitrarily, and works well.
+            self.word_embeddings = self.add_weight(
+                "weight",
+                shape=[self.vocab_size, self.hidden_size],
+                initializer=tf.random_normal_initializer(
+                    mean=0., stddev=self.hidden_size**-0.5))
+        super(TFBertEmbeddings, self).build(input_shape)
+
     @tf.function
-    def call(self, inputs, training=False):
+    def call(self, inputs, mode="embedding", training=False):
+        """Get token embeddings of inputs.
+            Args:
+                inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+                mode: string, a valid value is one of "embedding" and "linear".
+            Returns:
+                outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+                    shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+                    linear tensor, float32 with shape [batch_size, length, vocab_size].
+            Raises:
+                ValueError: if mode is not valid.
+
+            Shared weights logic adapted from
+                https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        if mode == "embedding":
+            return self._embedding(inputs, training=training)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError("mode {} is not valid.".format(mode))
+
+    def _embedding(self, inputs, training=False):
+        """Applies embedding based on inputs tensor."""
+        # Create binary mask of size [batch_size, length]
         input_ids, position_ids, token_type_ids = inputs
         seq_length = tf.shape(input_ids)[1]
@@ -160,7 +198,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         if token_type_ids is None:
             token_type_ids = tf.fill(tf.shape(input_ids), 0)
 
-        words_embeddings = self.word_embeddings(input_ids)
+        words_embeddings = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -170,6 +208,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         embeddings = self.dropout(embeddings)
         return embeddings
 
+    def _linear(self, inputs):
+        """Computes logits by running inputs through a linear layer.
+            Args:
+                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+            Returns:
+                float32 tensor with shape [batch_size, length, vocab_size].
+        """
+        batch_size = tf.shape(inputs)[0]
+        length = tf.shape(inputs)[1]
+
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
+
+        return tf.reshape(logits, [batch_size, length, self.vocab_size])
+
 
 class TFBertSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
@@ -448,8 +501,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')
 
-        # self.apply(self.init_weights)  # TODO check weights initialization
-
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -692,22 +743,14 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         super(TFBertForPreTraining, self).__init__(config)
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass  # TODO add weights tying
 
     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
         seq_relationship_score = self.cls_nsp(pooled_output)
 
         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
@@ -751,21 +794,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         super(TFBertForMaskedLM, self).__init__(config)
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, name='cls_mlm')
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-        """
-        pass  # TODO add weights tying
 
     @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
         sequence_output = outputs[0]
-        prediction_scores = self.cls_mlm(sequence_output)
+        prediction_scores = self.bert.embeddings(sequence_output, mode="linear", training=training)
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
......
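
The hunks above drop the separate `cls_mlm` head: both `TFBertForPreTraining` and `TFBertForMaskedLM` now call the shared embedding layer in "linear" mode, so the input embedding matrix doubles as the output projection (weight tying). Below is a minimal standalone sketch of that pattern, assuming only TensorFlow 2.x; the `SharedEmbedding` class and toy sizes are illustrative, not code from this commit.

# Sketch only (not library code): one weight matrix used both as the
# embedding table and as the vocabulary projection for the LM head.
import tensorflow as tf

class SharedEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, hidden_size, **kwargs):
        super(SharedEmbedding, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # Single shared variable, created lazily like TFBertEmbeddings.build above.
        self.word_embeddings = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size],
            initializer=tf.random_normal_initializer(stddev=self.hidden_size ** -0.5))
        super(SharedEmbedding, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            # int ids [batch, length] -> embeddings [batch, length, hidden_size]
            return tf.gather(self.word_embeddings, inputs)
        elif mode == "linear":
            # hidden states [batch, length, hidden_size] -> logits [batch, length, vocab_size]
            batch_size, length = tf.shape(inputs)[0], tf.shape(inputs)[1]
            x = tf.reshape(inputs, [-1, self.hidden_size])
            logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
            return tf.reshape(logits, [batch_size, length, self.vocab_size])
        raise ValueError("mode {} is not valid.".format(mode))

# Usage: the same layer instance produces embeddings and, later, vocabulary logits,
# so no separate output matrix is trained.
layer = SharedEmbedding(vocab_size=100, hidden_size=16)
hidden = layer(tf.constant([[1, 2, 3]]), mode="embedding")   # shape (1, 3, 16)
logits = layer(hidden, mode="linear")                        # shape (1, 3, 100)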
@@ -64,7 +64,7 @@ class TFPreTrainedModel(tf.keras.Model):
         self.config = config
 
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Module from a provided token Embedding Module.
+        """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end
@@ -77,12 +77,25 @@ class TFPreTrainedModel(tf.keras.Model):
         Return: ``torch.nn.Embeddings``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
-        raise NotImplementedError
-
-    def _tie_or_clone_weights(self, first_module, second_module):
-        """ Tie or clone module weights depending of weither we are using TorchScript or not
-        """
-        raise NotImplementedError
+        # if new_num_tokens is None:
+        #     return old_embeddings
+
+        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        # if old_num_tokens == new_num_tokens:
+        #     return old_embeddings
+
+        # # Build new embeddings
+        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        # new_embeddings.to(old_embeddings.weight.device)
+
+        # # initialize all new embeddings (in particular added tokens)
+        # self._init_weights(new_embeddings)
+
+        # # Copy word embeddings from the previous weights
+        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+
+        # return new_embeddings
 
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
......
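
The `_get_resized_embeddings` body above is only carried over as commented-out PyTorch code for now; the TF port is left open. A hedged sketch of the same copy-and-pad logic applied to a plain `tf.Variable` table of shape [num_tokens, hidden_size] follows; the function name and initializer below are assumptions, not part of this commit.

# Sketch only: the commented-out PyTorch resize logic, transposed to a tf.Variable.
import tensorflow as tf

def get_resized_embeddings(old_embeddings, new_num_tokens=None):
    """Return a [new_num_tokens, hidden_size] variable, reusing rows of old_embeddings."""
    if new_num_tokens is None:
        return old_embeddings

    old_num_tokens, old_embedding_dim = old_embeddings.shape.as_list()
    if old_num_tokens == new_num_tokens:
        return old_embeddings

    # Existing rows are copied over; any newly added rows are randomly initialized.
    init = tf.random_normal_initializer(stddev=old_embedding_dim ** -0.5)
    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    new_value = tf.concat(
        [old_embeddings[:num_tokens_to_copy, :],
         init(shape=[max(0, new_num_tokens - old_num_tokens), old_embedding_dim])],
        axis=0)
    return tf.Variable(new_value, name="resized_word_embeddings")

# Usage: grow a 100-token table to 105 rows, keeping the first 100 rows unchanged.
old = tf.Variable(tf.zeros([100, 16]))
new = get_resized_embeddings(old, new_num_tokens=105)   # shape (105, 16)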
@@ -64,44 +64,40 @@ class TFCommonTestCases:
     def test_attention_outputs(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        # for model_class in self.all_model_classes:
-        #     config.output_attentions = True
-        #     config.output_hidden_states = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     attentions = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, False)
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
-        #     out_len = len(outputs)
-
-        #     # Check attention is always last and order is fine
-        #     config.output_attentions = True
-        #     config.output_hidden_states = True
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     self.assertEqual(out_len+1, len(outputs))
-        #     self.assertEqual(model.config.output_attentions, True)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-
-        #     attentions = outputs[-1]
-        #     self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-        #     self.assertListEqual(
-        #         list(attentions[0].shape[-3:]),
-        #         [self.model_tester.num_attention_heads,
-        #          self.model_tester.seq_length,
-        #          self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            config.output_attentions = True
+            config.output_hidden_states = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, False)
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            config.output_attentions = True
+            config.output_hidden_states = True
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            self.assertEqual(out_len+1, len(outputs))
+            self.assertEqual(model.config.output_attentions, True)
+            self.assertEqual(model.config.output_hidden_states, True)
+
+            attentions = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads,
+                 self.model_tester.seq_length,
+                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
 
     def test_headmasking(self):
         pass
@@ -178,22 +174,20 @@ class TFCommonTestCases:
     def test_hidden_states_output(self):
-        pass
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        # for model_class in self.all_model_classes:
-        #     config.output_hidden_states = True
-        #     config.output_attentions = False
-        #     model = model_class(config)
-        #     model.eval()
-        #     outputs = model(**inputs_dict)
-        #     hidden_states = outputs[-1]
-        #     self.assertEqual(model.config.output_attentions, False)
-        #     self.assertEqual(model.config.output_hidden_states, True)
-        #     self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-        #     self.assertListEqual(
-        #         list(hidden_states[0].shape[-2:]),
-        #         [self.model_tester.seq_length, self.model_tester.hidden_size])
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            config.output_hidden_states = True
+            config.output_attentions = False
+            model = model_class(config)
+            outputs = model(inputs_dict)
+            hidden_states = [t.numpy() for t in outputs[-1]]
+            self.assertEqual(model.config.output_attentions, False)
+            self.assertEqual(model.config.output_hidden_states, True)
+            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [self.model_tester.seq_length, self.model_tester.hidden_size])
 
     def test_resize_tokens_embeddings(self):
......