Unverified Commit 9996f697 authored by Julien Plu, committed by GitHub

Fix saved model creation (#5468)

* Fix TF Serving when output_hidden_states and output_attentions are True

* Add tests for saved model creation + bug fix for multiple choice models

* remove unused import

* Fix the input for several layers

* Fix test

* Fix conflict printing

* Apply style

* Fix XLM and Flaubert for TensorFlow

* Apply style

* Fix TF check version

* Apply style

* Trigger CI
parent 5a0dac53
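For context (not part of the commit): the changes below are what let a TF 2.x transformers model be exported as a SavedModel for TF Serving when output_hidden_states or output_attentions is enabled. A minimal, hedged sketch of that kind of export follows; the checkpoint name, serving signature, and export path are illustrative placeholders, not code from this PR.

# Illustrative only: export a TF transformers model with an explicit serving signature.
import tensorflow as tf
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-cased")  # placeholder checkpoint

@tf.function(input_signature=[tf.TensorSpec((None, None), tf.int32, name="input_ids")])
def serving_fn(input_ids):
    outputs = model(input_ids)
    return {"last_hidden_state": outputs[0]}

tf.saved_model.save(model, "saved_model/bert", signatures={"serving_default": serving_fn})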
@@ -35,7 +35,6 @@ from .modeling_tf_utils import (
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
-    cast_bool_to_primitive,
     get_initializer,
     keras_serializable,
     shape_list,
@@ -99,7 +98,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         )
         super().build(input_shape)
 
-    def call(self, inputs, mode="embedding", training=False):
+    def call(
+        self,
+        input_ids=None,
+        position_ids=None,
+        token_type_ids=None,
+        inputs_embeds=None,
+        mode="embedding",
+        training=False,
+    ):
         """Get token embeddings of inputs.
         Args:
             inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
@@ -115,15 +122,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
         if mode == "embedding":
-            return self._embedding(inputs, training=training)
+            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         elif mode == "linear":
-            return self._linear(inputs)
+            return self._linear(input_ids)
         else:
             raise ValueError("mode {} is not valid.".format(mode))
 
-    def _embedding(self, inputs, training=False):
+    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
+        assert not (input_ids is None and inputs_embeds is None)
 
         if input_ids is not None:
             input_shape = shape_list(input_ids)
@@ -175,6 +182,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
         ), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.output_attentions = config.output_attentions
 
         self.query = tf.keras.layers.Dense(
             self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
@@ -192,9 +200,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
         x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
         return tf.transpose(x, perm=[0, 2, 1, 3])
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
         batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
@@ -233,9 +239,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
             context_layer, (batch_size, -1, self.all_head_size)
         )  # (batch_size, seq_len_q, all_head_size)
 
-        outputs = (
-            (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
-        )
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
         return outputs
@@ -248,9 +252,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    def call(self, inputs, training=False):
-        hidden_states, input_tensor = inputs
-
+    def call(self, hidden_states, input_tensor, training=False):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -262,6 +264,7 @@ class TFAlbertAttention(TFBertSelfAttention):
         super().__init__(config, **kwargs)
 
         self.hidden_size = config.hidden_size
+        self.output_attentions = config.output_attentions
         self.dense = tf.keras.layers.Dense(
             config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
@@ -271,9 +274,7 @@ class TFAlbertAttention(TFBertSelfAttention):
     def prune_heads(self, heads):
         raise NotImplementedError
 
-    def call(self, inputs, training=False):
-        input_tensor, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
         batch_size = shape_list(input_tensor)[0]
         mixed_query_layer = self.query(input_tensor)
         mixed_key_layer = self.key(input_tensor)
@@ -312,9 +313,7 @@ class TFAlbertAttention(TFBertSelfAttention):
             context_layer, (batch_size, -1, self.all_head_size)
         )  # (batch_size, seq_len_q, all_head_size)
 
-        self_outputs = (
-            (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
-        )
+        self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
         hidden_states = self_outputs[0]
@@ -349,11 +348,9 @@ class TFAlbertLayer(tf.keras.layers.Layer):
         )
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
         attention_outputs = self.attention(
-            [hidden_states, attention_mask, head_mask, output_attentions], training=training
+            hidden_states, attention_mask, head_mask, output_attentions, training=training
         )
         ffn_output = self.ffn(attention_outputs[0])
         ffn_output = self.activation(ffn_output)
@@ -371,32 +368,32 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
         self.albert_layers = [
             TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
         ]
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
         layer_hidden_states = ()
         layer_attentions = ()
 
         for layer_index, albert_layer in enumerate(self.albert_layers):
             layer_output = albert_layer(
-                [hidden_states, attention_mask, head_mask[layer_index], output_attentions], training=training
+                hidden_states, attention_mask, head_mask[layer_index], output_attentions, training=training
             )
             hidden_states = layer_output[0]
 
-            if cast_bool_to_primitive(output_attentions) is True:
+            if output_attentions:
                 layer_attentions = layer_attentions + (layer_output[1],)
 
-            if cast_bool_to_primitive(output_hidden_states) is True:
+            if output_hidden_states:
                 layer_hidden_states = layer_hidden_states + (hidden_states,)
 
         outputs = (hidden_states,)
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             outputs = outputs + (layer_hidden_states,)
-        if cast_bool_to_primitive(output_attentions) is True:
+        if output_attentions:
             outputs = outputs + (layer_attentions,)
         # last-layer hidden state, (layer hidden states), (layer attentions)
         return outputs
@@ -417,13 +414,11 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
             for i in range(config.num_hidden_groups)
         ]
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
         all_attentions = ()
 
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             all_hidden_states = (hidden_states,)
 
         for i in range(self.config.num_hidden_layers):
@@ -434,27 +429,25 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
             group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
 
             layer_group_output = self.albert_layer_groups[group_idx](
-                [
-                    hidden_states,
-                    attention_mask,
-                    head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
-                    output_attentions,
-                    output_hidden_states,
-                ],
+                hidden_states,
+                attention_mask,
+                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
+                output_attentions,
+                output_hidden_states,
                 training=training,
             )
             hidden_states = layer_group_output[0]
 
-            if cast_bool_to_primitive(output_attentions) is True:
+            if output_attentions:
                 all_attentions = all_attentions + layer_group_output[-1]
 
-            if cast_bool_to_primitive(output_hidden_states) is True:
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
         outputs = (hidden_states,)
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if cast_bool_to_primitive(output_attentions) is True:
+        if output_attentions:
             outputs = outputs + (all_attentions,)
         # last-layer hidden state, (all hidden states), (all attentions)
@@ -619,9 +612,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
-            [embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
+            embedding_output,
+            extended_attention_mask,
+            head_mask,
+            output_attentions,
+            output_hidden_states,
             training=training,
         )
@@ -1274,7 +1271,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
 
-        flat_inputs = [
+        outputs = self.albert(
             flat_input_ids,
             flat_attention_mask,
             flat_token_type_ids,
@@ -1283,9 +1280,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds,
             output_attentions,
             output_hidden_states,
-        ]
-
-        outputs = self.albert(flat_inputs, training=training)
+            training=training,
+        )
 
         pooled_output = outputs[1]
...
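Aside on the ALBERT changes above (an illustration, not code from the PR): the inner layers now receive explicit tensors and plain Python booleans instead of a single list, which is what makes them serializable without cast_bool_to_primitive. A rough sketch of the new calling convention, with dummy tensors and a toy config; the import path follows the library layout at the time of this commit.

# Hedged sketch: call an inner layer with explicit arguments, as the diff above now does.
import tensorflow as tf
from transformers import AlbertConfig
from transformers.modeling_tf_albert import TFAlbertSelfAttention

config = AlbertConfig(hidden_size=64, num_attention_heads=4)  # toy sizes, not a real checkpoint
layer = TFAlbertSelfAttention(config)

hidden_states = tf.random.uniform((2, 8, 64))
attention_mask = tf.zeros((2, 1, 1, 8))  # additive mask, all positions visible
# positional args: hidden_states, attention_mask, head_mask, output_attentions
context, attention_probs = layer(hidden_states, attention_mask, None, True, training=False)
print(context.shape, attention_probs.shape)  # (2, 8, 64) (2, 4, 8, 8)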
@@ -36,7 +36,6 @@ from .modeling_tf_utils import (
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
-    cast_bool_to_primitive,
     get_initializer,
     keras_serializable,
     shape_list,
@@ -81,6 +80,7 @@ def gelu(x):
         Also see https://arxiv.org/abs/1606.08415
     """
     cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+
     return x * cdf
@@ -94,6 +94,7 @@ def gelu_new(x):
         `x` with the GELU activation applied.
     """
     cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
+
     return x * cdf
@@ -118,7 +119,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
         self.initializer_range = config.initializer_range
-
         self.position_embeddings = tf.keras.layers.Embedding(
             config.max_position_embeddings,
             config.hidden_size,
@@ -149,7 +149,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         )
         super().build(input_shape)
 
-    def call(self, inputs, mode="embedding", training=False):
+    def call(
+        self,
+        input_ids=None,
+        position_ids=None,
+        token_type_ids=None,
+        inputs_embeds=None,
+        mode="embedding",
+        training=False,
+    ):
         """Get token embeddings of inputs.
         Args:
             inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
@@ -165,15 +173,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
         if mode == "embedding":
-            return self._embedding(inputs, training=training)
+            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         elif mode == "linear":
-            return self._linear(inputs)
+            return self._linear(input_ids)
         else:
             raise ValueError("mode {} is not valid.".format(mode))
 
-    def _embedding(self, inputs, training=False):
+    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
+        assert not (input_ids is None and inputs_embeds is None)
 
         if input_ids is not None:
             input_shape = shape_list(input_ids)
@@ -181,19 +189,22 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
             input_shape = shape_list(inputs_embeds)[:-1]
 
         seq_length = input_shape[1]
         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
+
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)
+
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
 
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
+
         return embeddings
 
     def _linear(self, inputs):
@@ -205,7 +216,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         """
         batch_size = shape_list(inputs)[0]
         length = shape_list(inputs)[1]
-
         x = tf.reshape(inputs, [-1, self.hidden_size])
         logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
@@ -215,6 +225,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
 class TFBertSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
+
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
@@ -225,7 +236,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         assert config.hidden_size % config.num_attention_heads == 0
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
         self.query = tf.keras.layers.Dense(
             self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
         )
@@ -235,21 +245,18 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         self.value = tf.keras.layers.Dense(
             self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
         )
         self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
 
     def transpose_for_scores(self, x, batch_size):
         x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
         return tf.transpose(x, perm=[0, 2, 1, 3])
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
         batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
 
         query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
         key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
@@ -277,15 +284,11 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
             attention_probs = attention_probs * head_mask
 
         context_layer = tf.matmul(attention_probs, value_layer)
         context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
         context_layer = tf.reshape(
             context_layer, (batch_size, -1, self.all_head_size)
         )  # (batch_size, seq_len_q, all_head_size)
-
-        outputs = (
-            (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
-        )
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
         return outputs
@@ -299,12 +302,11 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    def call(self, inputs, training=False):
-        hidden_states, input_tensor = inputs
-
+    def call(self, hidden_states, input_tensor, training=False):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+
         return hidden_states
@@ -317,14 +319,13 @@ class TFBertAttention(tf.keras.layers.Layer):
     def prune_heads(self, heads):
         raise NotImplementedError
 
-    def call(self, inputs, training=False):
-        input_tensor, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
         self_outputs = self.self_attention(
-            [input_tensor, attention_mask, head_mask, output_attentions], training=training
+            input_tensor, attention_mask, head_mask, output_attentions, training=training
         )
-        attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
+        attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+
         return outputs
@@ -334,6 +335,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
         self.dense = tf.keras.layers.Dense(
             config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
+
         if isinstance(config.hidden_act, str):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
         else:
@@ -342,6 +344,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
+
         return hidden_states
@@ -354,12 +357,11 @@ class TFBertOutput(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    def call(self, inputs, training=False):
-        hidden_states, input_tensor = inputs
-
+    def call(self, hidden_states, input_tensor, training=False):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+
         return hidden_states
@@ -370,16 +372,15 @@ class TFBertLayer(tf.keras.layers.Layer):
         self.intermediate = TFBertIntermediate(config, name="intermediate")
         self.bert_output = TFBertOutput(config, name="output")
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
         attention_outputs = self.attention(
-            [hidden_states, attention_mask, head_mask, output_attentions], training=training
+            hidden_states, attention_mask, head_mask, output_attentions, training=training
         )
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
-        layer_output = self.bert_output([intermediate_output, attention_output], training=training)
+        layer_output = self.bert_output(intermediate_output, attention_output, training=training)
         outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+
         return outputs
@@ -388,32 +389,34 @@ class TFBertEncoder(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
-
+    def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
         all_hidden_states = ()
         all_attentions = ()
 
         for i, layer_module in enumerate(self.layer):
-            if cast_bool_to_primitive(output_hidden_states) is True:
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(
-                [hidden_states, attention_mask, head_mask[i], output_attentions], training=training
+                hidden_states, attention_mask, head_mask[i], output_attentions, training=training
             )
             hidden_states = layer_outputs[0]
 
-            if cast_bool_to_primitive(output_attentions) is True:
+            if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
 
         # Add last layer
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         outputs = (hidden_states,)
-        if cast_bool_to_primitive(output_hidden_states) is True:
+
+        if output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if cast_bool_to_primitive(output_attentions) is True:
+
+        if output_attentions:
             outputs = outputs + (all_attentions,)
 
         return outputs  # outputs, (hidden states), (attentions)
@@ -432,6 +435,7 @@ class TFBertPooler(tf.keras.layers.Layer):
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
         pooled_output = self.dense(first_token_tensor)
+
         return pooled_output
@@ -441,16 +445,19 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
         self.dense = tf.keras.layers.Dense(
             config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
+
         if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
+
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
 
     def call(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
+
         return hidden_states
@@ -472,6 +479,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.input_embeddings(hidden_states, mode="linear")
         hidden_states = hidden_states + self.bias
+
         return hidden_states
@@ -482,6 +490,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
     def call(self, sequence_output):
         prediction_scores = self.predictions(sequence_output)
+
         return prediction_scores
@@ -494,6 +503,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
     def call(self, pooled_output):
         seq_relationship_score = self.seq_relationship(pooled_output)
+
         return seq_relationship_score
@@ -507,7 +517,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.initializer_range = config.initializer_range
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
         self.encoder = TFBertEncoder(config, name="encoder")
         self.pooler = TFBertPooler(config, name="pooler")
@@ -605,18 +614,22 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
-            [embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
+            embedding_output,
+            extended_attention_mask,
+            head_mask,
+            output_attentions,
+            output_hidden_states,
             training=training,
         )
 
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
 
         outputs = (sequence_output, pooled_output,) + encoder_outputs[
             1:
         ]  # add hidden_states and attentions if they are here
 
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
@@ -1211,8 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             if inputs_embeds is not None
             else None
         )
-
-        flat_inputs = [
+        outputs = self.bert(
             flat_input_ids,
             flat_attention_mask,
             flat_token_type_ids,
@@ -1221,16 +1233,12 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             flat_inputs_embeds,
             output_attentions,
             output_hidden_states,
-        ]
-
-        outputs = self.bert(flat_inputs, training=training)
+            training=training,
+        )
 
         pooled_output = outputs[1]
 
         pooled_output = self.dropout(pooled_output, training=training)
         logits = self.classifier(pooled_output)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
 
         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
 
         if labels is not None:
...
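The "tests for saved model creation" mentioned in the commit message boil down to checks of roughly this shape (a hedged sketch, not the actual test code; the config sizes and output directory are arbitrary): with output_hidden_states and output_attentions enabled, the model should still export cleanly.

# Hedged sketch of a saved-model creation check with the output flags turned on.
import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
    output_hidden_states=True,
    output_attentions=True,
)
model = TFBertModel(config)
input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]], dtype=tf.int32)
model(input_ids)  # build the model once before export
tf.saved_model.save(model, "/tmp/tf_bert_saved_model")  # placeholder path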
@@ -27,7 +27,6 @@ from .modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
-    cast_bool_to_primitive,
     keras_serializable,
     shape_list,
 )
@@ -87,10 +86,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
 class TFMultiHeadAttention(tf.keras.layers.Layer):
-    def __init__(self, d_model_size, num_heads, **kwargs):
+    def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
         super().__init__(**kwargs)
         self.num_heads = num_heads
         self.d_model_size = d_model_size
+        self.output_attentions = output_attentions
 
         self.depth = int(d_model_size / self.num_heads)
@@ -104,8 +104,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
         return tf.transpose(x, perm=[0, 2, 1, 3])
 
-    def call(self, inputs, training=False):
-        v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
+    def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
         batch_size = shape_list(q)[0]
 
         q = self.Wq(q)
@@ -121,10 +120,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
 
-        # to cope with keras serialization
-        use_cache = cast_bool_to_primitive(use_cache, True)
-
-        if use_cache is True:
+        if use_cache:
             present = tf.stack((k, v), axis=0)
         else:
             present = (None,)
@@ -134,10 +130,11 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         attn = output[1]
         original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
         output = self.dense(original_size_attention)
         outputs = (output, present)
-        if cast_bool_to_primitive(output_attentions) is True:
+
+        if output_attentions:
             outputs = outputs + (attn,)
 
         return outputs
@@ -156,10 +153,16 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
 class TFEncoderLayer(tf.keras.layers.Layer):
-    def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
+    def __init__(
+        self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
+    ):
         super().__init__(**kwargs)
 
-        self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
+        self.output_attentions = output_attentions
+        self.multi_head_attention = TFMultiHeadAttention(
+            d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
+        )
         self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
 
         self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
@@ -168,11 +171,18 @@ class TFEncoderLayer(tf.keras.layers.Layer):
         self.dropout1 = tf.keras.layers.Dropout(rate)
         self.dropout2 = tf.keras.layers.Dropout(rate)
 
-    def call(self, inputs, training=False):
-        x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
+    def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
         normed = self.layernorm1(x)
         attn_outputs = self.multi_head_attention(
-            [normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
+            normed,
+            normed,
+            normed,
+            mask,
+            layer_past,
+            attention_mask,
+            head_mask,
+            use_cache,
+            output_attentions,
             training=training,
         )
         attn_output = attn_outputs[0]
@@ -215,6 +225,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                 config.dff,
                 config.resid_pdrop,
                 config.layer_norm_epsilon,
+                self.output_attentions,
                 name="h_._{}".format(i),
             )
             for i in range(config.n_layer)
@@ -367,31 +378,37 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         all_hidden_states = ()
         all_attentions = []
         for i, (h, layer_past) in enumerate(zip(self.h, past)):
-            if cast_bool_to_primitive(output_hidden_states) is True:
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
             outputs = h(
-                [hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
+                hidden_states,
+                mask,
+                layer_past,
+                attention_mask,
+                head_mask[i],
+                use_cache,
+                output_attentions,
                 training=training,
             )
             hidden_states, present = outputs[:2]
-            if use_cache is True:
+            if use_cache:
                 presents = presents + (present,)
 
-            if cast_bool_to_primitive(output_attentions) is True:
+            if output_attentions:
                 all_attentions.append(outputs[2])
 
         hidden_states = self.layernorm(hidden_states)
         hidden_states = tf.reshape(hidden_states, output_shape)
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         outputs = (hidden_states,)
-        if use_cache is True:
+        if use_cache:
             outputs = outputs + (presents,)
-        if cast_bool_to_primitive(output_hidden_states) is True:
+        if output_hidden_states:
             outputs = outputs + (all_hidden_states,)
-        if cast_bool_to_primitive(output_attentions) is True:
+        if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...
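A note on why the plain Python booleans in the CTRL changes above are enough (generic TensorFlow behaviour, not code from the PR): when a flag such as use_cache or output_attentions is a Python bool, tf.function resolves the if statement at trace time, so the exported graph has a fixed output structure; a tensor-valued flag would instead need tf.cond and a single output signature.

# Hedged sketch: a Python-bool flag is folded away at trace time.
import tensorflow as tf

def make_fn(use_cache):
    @tf.function
    def fn(x):
        if use_cache:          # resolved while tracing, not at run time
            return x, x * 2.0  # stands in for also returning a "present" cache entry
        return (x,)
    return fn

print(make_fn(True)(tf.constant(1.0)))   # two tensors: 1.0 and 2.0
print(make_fn(False)(tf.constant(1.0)))  # a single tensor: 1.0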
@@ -37,7 +37,6 @@ from .modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFSharedEmbeddings,
     TFTokenClassificationLoss,
-    cast_bool_to_primitive,
     get_initializer,
     keras_serializable,
     shape_list,
@@ -114,7 +113,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
         )
         super().build(input_shape)
 
-    def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
+    def call(self, input_ids=None, position_ids=None, inputs_embeds=None, mode="embedding", training=False):
         """Get token embeddings of inputs.
         Args:
             inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
@@ -130,13 +129,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
         if mode == "embedding":
-            return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
+            return self._embedding(input_ids, position_ids, inputs_embeds, training=training)
         elif mode == "linear":
-            return self._linear(inputs)
+            return self._linear(input_ids)
         else:
             raise ValueError("mode {} is not valid.".format(mode))
 
-    def _embedding(self, inputs, inputs_embeds=None, training=False):
+    def _embedding(self, input_ids, position_ids, inputs_embeds, training=False):
         """
         Parameters
         ----------
@@ -148,11 +147,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
         embeddings: tf.Tensor(bs, max_seq_length, dim)
             The embedded tokens (plus position embeddings, no token_type embeddings)
         """
-        if not isinstance(inputs, (tuple, list)):
-            input_ids = inputs
-            position_ids = None
-        else:
-            input_ids, position_ids = inputs
+        assert not (input_ids is None and inputs_embeds is None)
 
         if input_ids is not None:
             seq_length = shape_list(input_ids)[1]
@@ -194,6 +189,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         self.n_heads = config.n_heads
         self.dim = config.dim
         self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
+        self.output_attentions = config.output_attentions
 
         assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
@@ -215,7 +211,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
     def prune_heads(self, heads):
         raise NotImplementedError
 
-    def call(self, inputs, training=False):
+    def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
         """
         Parameters
         ----------
@@ -231,7 +227,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         context: tf.Tensor(bs, seq_length, dim)
             Contextualized layer. Optional: only if `output_attentions=True`
         """
-        query, key, value, mask, head_mask, output_attentions = inputs
         bs, q_length, dim = shape_list(query)
         k_length = shape_list(key)[1]
         # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
@@ -270,7 +265,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
 
-        if cast_bool_to_primitive(output_attentions) is True:
+        if output_attentions:
             return (context, weights)
         else:
             return (context,)
@@ -310,6 +305,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
         self.hidden_dim = config.hidden_dim
         self.dropout = tf.keras.layers.Dropout(config.dropout)
         self.activation = config.activation
+        self.output_attentions = config.output_attentions
 
         assert (
             config.dim % config.n_heads == 0
@@ -321,7 +317,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
         self.ffn = TFFFN(config, name="ffn")
         self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
 
-    def call(self, inputs, training=False):  # removed: src_enc=None, src_len=None
+    def call(self, x, attn_mask, head_mask, output_attentions, training=False):  # removed: src_enc=None, src_len=None
         """
         Parameters
         ----------
@@ -335,11 +331,9 @@ class TFTransformerBlock(tf.keras.layers.Layer):
         ffn_output: tf.Tensor(bs, seq_length, dim)
             The output of the transformer block contextualization.
         """
-        x, attn_mask, head_mask, output_attentions = inputs
-
         # Self-Attention
-        sa_output = self.attention([x, x, x, attn_mask, head_mask, output_attentions], training=training)
-        if cast_bool_to_primitive(output_attentions) is True:
+        sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
+        if output_attentions:
             sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
         else:  # To handle these `output_attention` or `output_hidden_states` cases returning tuples
             # assert type(sa_output) == tuple
@@ -351,7 +345,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
         ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
         output = (ffn_output,)
 
-        if cast_bool_to_primitive(output_attentions) is True:
+        if output_attentions:
             output = (sa_weights,) + output
 
         return output
@@ -360,10 +354,12 @@ class TFTransformer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.n_layers = config.n_layers
+        self.output_hidden_states = config.output_hidden_states
+        self.output_attentions = config.output_attentions
 
         self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
 
-    def call(self, inputs, training=False):
+    def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, training=False):
         """
         Parameters
         ----------
...@@ -383,34 +379,32 @@ class TFTransformer(tf.keras.layers.Layer): ...@@ -383,34 +379,32 @@ class TFTransformer(tf.keras.layers.Layer):
Tuple of length n_layers with the attention weights from each layer Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True Optional: only if output_attentions=True
""" """
x, attn_mask, head_mask, output_attentions, output_hidden_states = inputs
all_hidden_states = () all_hidden_states = ()
all_attentions = () all_attentions = ()
hidden_state = x hidden_state = x
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i], output_attentions], training=training) layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
hidden_state = layer_outputs[-1] hidden_state = layer_outputs[-1]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
assert len(layer_outputs) == 2, f"Incorrect number of outputs {len(layer_outputs)} instead of 2" assert len(layer_outputs) == 2
attentions = layer_outputs[0] attentions = layer_outputs[0]
all_attentions = all_attentions + (attentions,) all_attentions = all_attentions + (attentions,)
else: else:
assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1" assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
# Add last layer # Add last layer
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,) outputs = (hidden_state,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
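The hunk above replaces the packed `inputs` list and the `cast_bool_to_primitive` checks with explicit call arguments and plain Python booleans, so the shape of the returned tuple is fixed at trace time instead of being decided inside the graph. A minimal sketch of that pattern with a toy Keras layer (names, shapes and the fake attention math are illustrative, not the DistilBERT code):

    import tensorflow as tf

    class ToyBlock(tf.keras.layers.Layer):
        """Toy block mirroring the unpacked-argument / plain-bool pattern."""

        def __init__(self, dim, **kwargs):
            super().__init__(**kwargs)
            self.dense = tf.keras.layers.Dense(dim)

        def call(self, x, attn_mask, output_attentions, training=False):
            # output_attentions is a plain Python bool, so this branch is resolved
            # while tracing and the returned tuple has a fixed structure.
            weights = tf.nn.softmax(tf.matmul(x, x, transpose_b=True) + attn_mask, axis=-1)
            out = self.dense(tf.matmul(weights, x))
            if output_attentions:
                return (out, weights)
            return (out,)

    x = tf.random.normal((2, 5, 8))
    mask = tf.zeros((2, 5, 5))
    block = ToyBlock(8)
    print(len(block(x, mask, output_attentions=True)))   # 2
    print(len(block(x, mask, output_attentions=False)))  # 1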
...@@ -481,6 +475,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -481,6 +475,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.ones(input_shape) # (bs, seq_length) attention_mask = tf.ones(input_shape) # (bs, seq_length)
attention_mask = tf.cast(attention_mask, dtype=tf.float32) attention_mask = tf.cast(attention_mask, dtype=tf.float32)
# Prepare head mask if needed # Prepare head mask if needed
...@@ -491,11 +486,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -491,11 +486,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
if head_mask is not None: if head_mask is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers head_mask = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim) embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
tfmr_output = self.transformer( tfmr_output = self.transformer(
[embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states], training=training embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states, training=training
) )
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
...@@ -986,24 +982,21 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic ...@@ -986,24 +982,21 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
if inputs_embeds is not None if inputs_embeds is not None
else None else None
) )
distilbert_output = self.distilbert(
flat_inputs = [
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
head_mask, head_mask,
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
distilbert_output = self.distilbert(flat_inputs, training=training)
hidden_state = distilbert_output[0] # (bs, seq_len, dim) hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here
if labels is not None: if labels is not None:
......
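For the multiple-choice head just above, the model flattens the (batch_size, num_choices, seq_length) inputs into (batch_size * num_choices, seq_length) before calling the encoder, then folds the per-choice logits back with a final reshape. The reshaping arithmetic in isolation (the 1-unit Dense layer stands in for the classifier; everything else is illustrative):

    import tensorflow as tf

    batch_size, num_choices, seq_length, dim = 2, 4, 7, 16

    input_ids = tf.random.uniform((batch_size, num_choices, seq_length), maxval=100, dtype=tf.int32)
    flat_input_ids = tf.reshape(input_ids, (-1, seq_length))        # (8, 7)

    # Stand-in for the encoder output: last hidden state for the flat batch.
    hidden_state = tf.random.normal((batch_size * num_choices, seq_length, dim))
    pooled_output = hidden_state[:, 0]                              # (8, 16), first token

    classifier = tf.keras.layers.Dense(1)
    logits = classifier(pooled_output)                              # (8, 1)
    reshaped_logits = tf.reshape(logits, (-1, num_choices))         # (2, 4), one score per choice
    print(flat_input_ids.shape, reshaped_logits.shape)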
...@@ -2,7 +2,8 @@ import logging ...@@ -2,7 +2,8 @@ import logging
import tensorflow as tf import tensorflow as tf
from .configuration_electra import ElectraConfig from transformers import ElectraConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import ( from .modeling_tf_utils import (
...@@ -71,7 +72,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): ...@@ -71,7 +72,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
) )
super().build(input_shape) super().build(input_shape)
def call(self, inputs, mode="embedding", training=False): def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
):
"""Get token embeddings of inputs. """Get token embeddings of inputs.
Args: Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
...@@ -87,15 +96,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): ...@@ -87,15 +96,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
""" """
if mode == "embedding": if mode == "embedding":
return self._embedding(inputs, training=training) return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
elif mode == "linear": elif mode == "linear":
return self._linear(inputs) return self._linear(input_ids)
else: else:
raise ValueError("mode {} is not valid.".format(mode)) raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, inputs, training=False): def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor.""" """Applies embedding based on inputs tensor."""
input_ids, position_ids, token_type_ids, inputs_embeds = inputs assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(input_ids)
...@@ -289,13 +298,17 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -289,13 +298,17 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask) head_mask = self.get_head_mask(head_mask)
hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
if hasattr(self, "embeddings_project"): if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states, training=training) hidden_states = self.embeddings_project(hidden_states, training=training)
hidden_states = self.encoder( hidden_states = self.encoder(
[hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states], hidden_states,
extended_attention_mask,
head_mask,
output_attentions,
output_hidden_states,
training=training, training=training,
) )
......
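TFElectraEmbeddings.call keeps its two modes after the refactor: "embedding" maps ids to vectors, while "linear" reuses the same weight matrix as an output projection (weight tying). A compact sketch of the idea with a bare variable instead of the library's embedding machinery (hypothetical names, not the transformers implementation):

    import tensorflow as tf

    vocab_size, hidden_size = 100, 16
    word_embeddings = tf.Variable(tf.random.normal((vocab_size, hidden_size)))

    def embeddings(inputs, mode="embedding"):
        if mode == "embedding":
            # ids -> vectors
            return tf.gather(word_embeddings, inputs)
        elif mode == "linear":
            # hidden states -> vocabulary logits, reusing the same matrix
            return tf.matmul(inputs, word_embeddings, transpose_b=True)
        raise ValueError("mode {} is not valid.".format(mode))

    ids = tf.constant([[5, 6, 7]])
    vectors = embeddings(ids, mode="embedding")      # (1, 3, 16)
    logits = embeddings(vectors, mode="linear")      # (1, 3, 100)
    print(vectors.shape, logits.shape)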
...@@ -22,7 +22,7 @@ import tensorflow as tf ...@@ -22,7 +22,7 @@ import tensorflow as tf
from .configuration_flaubert import FlaubertConfig from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list from .modeling_tf_utils import keras_serializable, shape_list
from .modeling_tf_xlm import ( from .modeling_tf_xlm import (
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple, TFXLMForQuestionAnsweringSimple,
...@@ -274,10 +274,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -274,10 +274,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
# self attention # self attention
if not self.pre_norm: if not self.pre_norm:
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -285,10 +285,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -285,10 +285,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
else: else:
tensor_normalized = self.layer_norm1[i](tensor) tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -311,7 +311,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -311,7 +311,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state # Add last hidden state
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True: if output_hidden_states:
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# update cache length # update cache length
...@@ -322,9 +322,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer): ...@@ -322,9 +322,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
# tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
outputs = (tensor,) outputs = (tensor,)
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True: if output_attentions:
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions) return outputs # outputs, (hidden_states), (attentions)
......
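The reason these cast_bool_to_primitive(...) checks can become plain `if output_attentions:` is that the flags are now Python booleans: tf.function then traces one concrete function per flag value, each with a static output structure, which is what tf.saved_model.save and TF Serving need. A toy illustration of that behaviour (not the transformers code):

    import tensorflow as tf

    @tf.function
    def forward(x, output_attentions):
        # output_attentions is a Python bool: each value gets its own concrete
        # function, and each concrete function returns a fixed number of tensors.
        y = x * 2.0
        if output_attentions:
            return (y, tf.ones_like(x))
        return (y,)

    spec = tf.TensorSpec((None, 4), tf.float32)
    with_attn = forward.get_concrete_function(spec, True)
    without_attn = forward.get_concrete_function(spec, False)
    print(len(with_attn.structured_outputs), len(without_attn.structured_outputs))  # 2 1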
...@@ -29,7 +29,6 @@ from .modeling_tf_utils import ( ...@@ -29,7 +29,6 @@ from .modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -75,6 +74,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -75,6 +74,7 @@ class TFAttention(tf.keras.layers.Layer):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = config.output_attentions
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
...@@ -95,8 +95,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -95,8 +95,7 @@ class TFAttention(tf.keras.layers.Layer):
m = i >= j - ns + nd m = i >= j - ns + nd
return tf.cast(m, dtype) return tf.cast(m, dtype)
def _attn(self, inputs, training=False): def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features] # q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True) w = tf.matmul(q, k, transpose_b=True)
if self.scale: if self.scale:
...@@ -121,7 +120,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -121,7 +120,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask w = w * head_mask
outputs = [tf.matmul(w, v)] outputs = [tf.matmul(w, v)]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -137,9 +136,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -137,9 +136,7 @@ class TFAttention(tf.keras.layers.Layer):
x = tf.reshape(x, new_x_shape) x = tf.reshape(x, new_x_shape)
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False): def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2) query, key, value = tf.split(x, 3, axis=2)
query = self.split_heads(query) query = self.split_heads(query)
...@@ -151,12 +148,12 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -151,12 +148,12 @@ class TFAttention(tf.keras.layers.Layer):
value = tf.concat([past_value, value], axis=-2) value = tf.concat([past_value, value], axis=-2)
# to cope with keras serialization # to cope with keras serialization
if cast_bool_to_primitive(use_cache, True) is True: if use_cache:
present = tf.stack([key, value], axis=0) present = tf.stack([key, value], axis=0)
else: else:
present = (None,) present = (None,)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training) attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -192,12 +189,10 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -192,12 +189,10 @@ class TFBlock(tf.keras.layers.Layer):
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
self.mlp = TFMLP(4 * nx, config, name="mlp") self.mlp = TFMLP(4 * nx, config, name="mlp")
def call(self, inputs, training=False): def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
a = self.ln_1(x) a = self.ln_1(x)
output_attn = self.attn( output_attn = self.attn(
[a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training a, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=training
) )
a = output_attn[0] # output_attn: a, present, (attentions) a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a x = x + a
...@@ -223,6 +218,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -223,6 +218,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.num_hidden_layers = config.n_layer self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.wte = TFSharedEmbeddings( self.wte = TFSharedEmbeddings(
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte" config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
...@@ -362,34 +359,39 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -362,34 +359,39 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
all_attentions = [] all_attentions = []
all_hidden_states = () all_hidden_states = ()
for i, (block, layer_past) in enumerate(zip(self.h, past)): for i, (block, layer_past) in enumerate(zip(self.h, past)):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block( outputs = block(
[hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions], hidden_states,
layer_past,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
training=training, training=training,
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
presents = presents + (present,) presents = presents + (present,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
all_attentions.append(outputs[2]) all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape) hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state # Add last hidden state
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache is True: if use_cache:
outputs = outputs + (presents,) outputs = outputs + (presents,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...@@ -738,13 +740,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -738,13 +740,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
input_shapes = shape_list(inputs_embeds)[:-1] input_shapes = shape_list(inputs_embeds)[:-1]
seq_length = input_shapes[-1] seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
transformer_outputs = self.transformer(
flat_inputs = [
flat_input_ids, flat_input_ids,
past, past,
flat_attention_mask, flat_attention_mask,
...@@ -755,18 +755,13 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -755,18 +755,13 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
use_cache, use_cache,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
transformer_outputs = self.transformer(flat_inputs, training=training)
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.wte(hidden_states, mode="linear") lm_logits = self.transformer.wte(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:] outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions) return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions)
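In the GPT-2 attention above, use_cache stacks the current key/value into a `present` tensor, and the next call concatenates the cached past along the sequence axis before attending. The cache arithmetic on its own, with toy shapes:

    import tensorflow as tf

    batch, n_head, seq_len, head_dim = 1, 2, 3, 4

    key = tf.random.normal((batch, n_head, seq_len, head_dim))
    value = tf.random.normal((batch, n_head, seq_len, head_dim))
    present = tf.stack([key, value], axis=0)                 # (2, batch, n_head, seq_len, head_dim)

    # Next decoding step: one new token, reuse the cached keys/values.
    new_key = tf.random.normal((batch, n_head, 1, head_dim))
    new_value = tf.random.normal((batch, n_head, 1, head_dim))
    past_key, past_value = tf.unstack(present, axis=0)
    key = tf.concat([past_key, new_key], axis=-2)            # (1, 2, 4, 4)
    value = tf.concat([past_value, new_value], axis=-2)
    print(key.shape, value.shape)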
...@@ -35,7 +35,6 @@ from .modeling_tf_utils import ( ...@@ -35,7 +35,6 @@ from .modeling_tf_utils import (
TFQuestionAnsweringLoss, TFQuestionAnsweringLoss,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -130,7 +129,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -130,7 +129,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
) )
super().build(input_shape) super().build(input_shape)
def call(self, inputs, mode="embedding", training=False): def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
):
"""Get token embeddings of inputs. """Get token embeddings of inputs.
Args: Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
...@@ -146,15 +153,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -146,15 +153,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
""" """
if mode == "embedding": if mode == "embedding":
return self._embedding(inputs, training=training) return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
elif mode == "linear": elif mode == "linear":
return self._linear(inputs) return self._linear(input_ids)
else: else:
raise ValueError("mode {} is not valid.".format(mode)) raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, inputs, training=False): def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor.""" """Applies embedding based on inputs tensor."""
input_ids, position_ids, token_type_ids, inputs_embeds = inputs assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(input_ids)
...@@ -196,6 +203,7 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -196,6 +203,7 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
embeddings = inputs_embeds + position_embeddings + token_type_embeddings embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings) embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training) embeddings = self.dropout(embeddings, training=training)
return embeddings return embeddings
def _linear(self, inputs): def _linear(self, inputs):
...@@ -224,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer): ...@@ -224,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
) )
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.output_attentions = config.output_attentions
assert config.hidden_size % config.num_attention_heads == 0 assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads) self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
...@@ -244,14 +253,13 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer): ...@@ -244,14 +253,13 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False): def call(
query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions = inputs self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False
):
batch_size = shape_list(attention_mask)[0] batch_size = shape_list(attention_mask)[0]
mixed_query_layer = self.query(query_tensor) mixed_query_layer = self.query(query_tensor)
mixed_key_layer = self.key(key_tensor) mixed_key_layer = self.key(key_tensor)
mixed_value_layer = self.value(value_tensor) mixed_value_layer = self.value(value_tensor)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
...@@ -285,9 +293,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer): ...@@ -285,9 +293,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
context_layer, (batch_size, -1, self.all_head_size) context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size) ) # (batch_size, seq_len_q, all_head_size)
outputs = ( outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
)
return outputs return outputs
...@@ -305,8 +311,7 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer): ...@@ -305,8 +311,7 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer):
if not self.use_bottleneck: if not self.use_bottleneck:
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False): def call(self, hidden_states, residual_tensor, training=False):
hidden_states, residual_tensor = inputs
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
if not self.use_bottleneck: if not self.use_bottleneck:
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -323,13 +328,22 @@ class TFMobileBertAttention(tf.keras.layers.Layer): ...@@ -323,13 +328,22 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
def prune_heads(self, heads): def prune_heads(self, heads):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, training=False): def call(
query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions = inputs self,
query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_mask,
head_mask,
output_attentions,
training=False,
):
self_outputs = self.self( self_outputs = self.self(
[query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions], training=training query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training
) )
attention_output = self.mobilebert_output([self_outputs[0], layer_input], training=training)
attention_output = self.mobilebert_output(self_outputs[0], layer_input, training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
...@@ -349,8 +363,7 @@ class TFOutputBottleneck(tf.keras.layers.Layer): ...@@ -349,8 +363,7 @@ class TFOutputBottleneck(tf.keras.layers.Layer):
) )
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, inputs, training=False): def call(self, hidden_states, residual_tensor, training=False):
hidden_states, residual_tensor = inputs
layer_outputs = self.dense(hidden_states) layer_outputs = self.dense(hidden_states)
layer_outputs = self.dropout(layer_outputs, training=training) layer_outputs = self.dropout(layer_outputs, training=training)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
...@@ -372,16 +385,14 @@ class TFMobileBertOutput(tf.keras.layers.Layer): ...@@ -372,16 +385,14 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
else: else:
self.bottleneck = TFOutputBottleneck(config, name="bottleneck") self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
def call(self, inputs, training=False): def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
hidden_states, residual_tensor_1, residual_tensor_2 = inputs
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
if not self.use_bottleneck: if not self.use_bottleneck:
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
else: else:
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
hidden_states = self.bottleneck([hidden_states, residual_tensor_2]) hidden_states = self.bottleneck(hidden_states, residual_tensor_2)
return hidden_states return hidden_states
...@@ -466,7 +477,6 @@ class TFMobileBertLayer(tf.keras.layers.Layer): ...@@ -466,7 +477,6 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.use_bottleneck = config.use_bottleneck self.use_bottleneck = config.use_bottleneck
self.num_feedforward_networks = config.num_feedforward_networks self.num_feedforward_networks = config.num_feedforward_networks
self.attention = TFMobileBertAttention(config, name="attention") self.attention = TFMobileBertAttention(config, name="attention")
self.intermediate = TFMobileBertIntermediate(config, name="intermediate") self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
self.mobilebert_output = TFMobileBertOutput(config, name="output") self.mobilebert_output = TFMobileBertOutput(config, name="output")
...@@ -478,16 +488,20 @@ class TFMobileBertLayer(tf.keras.layers.Layer): ...@@ -478,16 +488,20 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1) TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1)
] ]
def call(self, inputs, training=False): def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
hidden_states, attention_mask, head_mask, output_attentions = inputs
if self.use_bottleneck: if self.use_bottleneck:
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
else: else:
query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4 query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
attention_outputs = self.attention( attention_outputs = self.attention(
[query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions], query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_mask,
head_mask,
output_attentions,
training=training, training=training,
) )
...@@ -500,48 +514,57 @@ class TFMobileBertLayer(tf.keras.layers.Layer): ...@@ -500,48 +514,57 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
s += (attention_output,) s += (attention_output,)
intermediate_output = self.intermediate(attention_output) intermediate_output = self.intermediate(attention_output)
layer_output = self.mobilebert_output( layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training)
[intermediate_output, attention_output, hidden_states], training=training
)
outputs = ( outputs = (
(layer_output,) (layer_output,)
+ attention_outputs[1:] + attention_outputs[1:]
+ (0, query_tensor, key_tensor, value_tensor, layer_input, attention_output, intermediate_output) + (
tf.constant(0),
query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_output,
intermediate_output,
)
+ s + s
) # add attentions if we output them ) # add attentions if we output them
return outputs return outputs
class TFMobileBertEncoder(tf.keras.layers.Layer): class TFMobileBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False): def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
all_hidden_states = () all_hidden_states = ()
all_attentions = () all_attentions = ()
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module( layer_outputs = layer_module(
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training hidden_states, attention_mask, head_mask[i], output_attentions, training=training
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],) all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer # Add last layer
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions) return outputs # outputs, (hidden states), (attentions)
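The encoder loop here (and in the DistilBERT transformer earlier) accumulates per-layer hidden states and attentions in Python tuples and only appends them to the outputs when the corresponding flag is set. Reduced to its skeleton, with a stand-in layer (illustrative only):

    import tensorflow as tf

    def fake_layer(hidden_states):
        # Stand-in for a transformer block: returns (hidden_states, attention_weights).
        return hidden_states + 1.0, tf.zeros_like(hidden_states)

    def encode(hidden_states, n_layers, output_attentions, output_hidden_states):
        all_hidden_states, all_attentions = (), ()
        for _ in range(n_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states, attentions = fake_layer(hidden_states)
            if output_attentions:
                all_attentions = all_attentions + (attentions,)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (all hidden states), (all attentions)

    h = tf.zeros((1, 5, 8))
    print(len(encode(h, 3, True, True)))    # 3
    print(len(encode(h, 3, False, False)))  # 1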
...@@ -732,11 +755,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -732,11 +755,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states], embedding_output,
extended_attention_mask,
head_mask,
output_attentions,
output_hidden_states,
training=training, training=training,
) )
...@@ -1360,8 +1386,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic ...@@ -1360,8 +1386,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
if inputs_embeds is not None if inputs_embeds is not None
else None else None
) )
outputs = self.mobilebert(
flat_inputs = [
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
...@@ -1370,16 +1395,12 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic ...@@ -1370,16 +1395,12 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
outputs = self.mobilebert(flat_inputs, training=training)
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None: if labels is not None:
......
...@@ -29,7 +29,6 @@ from .modeling_tf_utils import ( ...@@ -29,7 +29,6 @@ from .modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -84,6 +83,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -84,6 +83,7 @@ class TFAttention(tf.keras.layers.Layer):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = config.output_attentions
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
...@@ -104,8 +104,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -104,8 +104,7 @@ class TFAttention(tf.keras.layers.Layer):
m = i >= j - ns + nd m = i >= j - ns + nd
return tf.cast(m, dtype) return tf.cast(m, dtype)
def _attn(self, inputs, training=False): def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
q, k, v, attention_mask, head_mask, output_attentions = inputs
# q, k, v have shape [batch, heads, sequence, features] # q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True) w = tf.matmul(q, k, transpose_b=True)
if self.scale: if self.scale:
...@@ -130,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -130,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer):
w = w * head_mask w = w * head_mask
outputs = [tf.matmul(w, v)] outputs = [tf.matmul(w, v)]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs.append(w) outputs.append(w)
return outputs return outputs
...@@ -146,16 +145,14 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -146,16 +145,14 @@ class TFAttention(tf.keras.layers.Layer):
x = tf.reshape(x, new_x_shape) x = tf.reshape(x, new_x_shape)
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False): def call(self, x, attention_mask, head_mask, output_attentions, training=False):
x, attention_mask, head_mask, output_attentions = inputs
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2) query, key, value = tf.split(x, 3, axis=2)
query = self.split_heads(query) query = self.split_heads(query)
key = self.split_heads(key) key = self.split_heads(key)
value = self.split_heads(value) value = self.split_heads(value)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training) attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
a = attn_outputs[0] a = attn_outputs[0]
a = self.merge_heads(a) a = self.merge_heads(a)
...@@ -191,10 +188,8 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -191,10 +188,8 @@ class TFBlock(tf.keras.layers.Layer):
self.mlp = TFMLP(4 * nx, config, name="mlp") self.mlp = TFMLP(4 * nx, config, name="mlp")
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
def call(self, inputs, training=False): def call(self, x, attention_mask, head_mask, output_attentions, training=False):
x, attention_mask, head_mask, output_attentions = inputs output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
output_attn = self.attn([x, attention_mask, head_mask, output_attentions], training=training)
a = output_attn[0] # output_attn: a, (attentions) a = output_attn[0] # output_attn: a, (attentions)
n = self.ln_1(x + a) n = self.ln_1(x + a)
...@@ -341,23 +336,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): ...@@ -341,23 +336,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
all_attentions = [] all_attentions = []
all_hidden_states = () all_hidden_states = ()
for i, block in enumerate(self.h): for i, block in enumerate(self.h):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training) outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training)
hidden_states = outputs[0] hidden_states = outputs[0]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
all_attentions.append(outputs[1]) all_attentions.append(outputs[1])
hidden_states = tf.reshape(hidden_states, output_shape) hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state # Add last hidden state
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
...@@ -671,13 +666,11 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -671,13 +666,11 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
input_shapes = shape_list(inputs_embeds)[:-1] input_shapes = shape_list(inputs_embeds)[:-1]
seq_length = input_shapes[-1] seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
transformer_outputs = self.transformer(
flat_inputs = [
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
...@@ -686,18 +679,13 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -686,18 +679,13 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
inputs_embeds, inputs_embeds,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
transformer_outputs = self.transformer(flat_inputs, training=training)
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:] outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
return outputs # lm logits, mc logits, (all hidden_states), (attentions) return outputs # lm logits, mc logits, (all hidden_states), (attentions)
...@@ -86,9 +86,9 @@ class TFRobertaEmbeddings(TFBertEmbeddings): ...@@ -86,9 +86,9 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :] position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
return position_ids return position_ids
def _embedding(self, inputs, training=False): def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor.""" """Applies embedding based on inputs tensor."""
input_ids, position_ids, token_type_ids, inputs_embeds = inputs assert not (input_ids is None and inputs_embeds is None)
if position_ids is None: if position_ids is None:
if input_ids is not None: if input_ids is not None:
...@@ -97,7 +97,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings): ...@@ -97,7 +97,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
return super()._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) return super()._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
@keras_serializable @keras_serializable
...@@ -546,8 +546,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -546,8 +546,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
outputs = self.roberta(
flat_inputs = [
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
...@@ -556,16 +555,12 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ...@@ -556,16 +555,12 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
inputs_embeds, inputs_embeds,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
outputs = self.roberta(flat_inputs, training=training)
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None: if labels is not None:
......
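When RoBERTa is fed inputs_embeds instead of input_ids, it cannot infer which positions are padding, so the position ids fall back to a simple range offset by the padding index, as in the `tf.range(self.padding_idx + 1, ...)` line above. The offset in isolation, assuming padding_idx == 1:

    import tensorflow as tf

    padding_idx = 1   # assumed pad index, as in RoBERTa
    seq_length = 6

    position_ids = tf.range(padding_idx + 1, seq_length + padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
    print(position_ids.numpy())  # [[2 3 4 5 6 7]]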
...@@ -115,6 +115,7 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -115,6 +115,7 @@ class TFT5Attention(tf.keras.layers.Layer):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.use_cache = config.use_cache self.use_cache = config.use_cache
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model self.d_model = config.d_model
...@@ -296,7 +297,7 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -296,7 +297,7 @@ class TFT5Attention(tf.keras.layers.Layer):
outputs = (context,) + present_key_value_state outputs = (context,) + present_key_value_state
if cast_bool_to_primitive(output_attentions, True) is True: if output_attentions:
outputs = outputs + (weights,) outputs = outputs + (weights,)
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
outputs = outputs + (position_bias,) outputs = outputs + (position_bias,)
...@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs_embeds, training=training) hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -727,23 +728,23 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -727,23 +728,23 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# append next layer key value states # append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],) all_attentions = all_attentions + (layer_outputs[2],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
# Add last layer # Add last layer
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
# need to check if is decoder here as well for special cases when using keras compile # need to check if is decoder here as well for special cases when using keras compile
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder: if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
outputs = outputs + (present_key_value_states,) outputs = outputs + (present_key_value_states,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
......
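The T5 hunks above all follow one pattern: the `cast_bool_to_primitive(...) is True` checks become plain `if output_attentions:` / `if output_hidden_states:` tests on Python booleans, so the branches are resolved while the graph is traced and the exported SavedModel keeps a fixed number of outputs. A minimal sketch of that pattern on a hypothetical layer (not the actual T5 code), assuming the flag is stored as a Python bool taken from the config:

import tensorflow as tf

class ToySelfAttention(tf.keras.layers.Layer):
    def __init__(self, units, output_attentions=False, **kwargs):
        super().__init__(**kwargs)
        self.out_proj = tf.keras.layers.Dense(units)
        self.output_attentions = output_attentions  # plain Python bool, not a tensor

    def call(self, hidden_states, output_attentions=None, training=False):
        if output_attentions is None:
            output_attentions = self.output_attentions
        scores = tf.matmul(hidden_states, hidden_states, transpose_b=True)
        weights = tf.nn.softmax(scores, axis=-1)
        context = self.out_proj(tf.matmul(weights, hidden_states))
        outputs = (context,)
        if output_attentions:  # evaluated at trace time, so the output arity is static
            outputs = outputs + (weights,)
        return outputs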
...@@ -24,13 +24,7 @@ import tensorflow as tf ...@@ -24,13 +24,7 @@ import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import ( from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
TFPreTrainedModel,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
...@@ -119,6 +113,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -119,6 +113,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
r_w_bias=None, r_w_bias=None,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
init_std=0.02, init_std=0.02,
output_attentions=False,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -127,6 +122,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -127,6 +122,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
self.d_model = d_model self.d_model = d_model
self.d_head = d_head self.d_head = d_head
self.dropout = dropout self.dropout = dropout
self.output_attentions = output_attentions
self.qkv_net = tf.keras.layers.Dense( self.qkv_net = tf.keras.layers.Dense(
3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net" 3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net"
...@@ -175,8 +171,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -175,8 +171,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
return x return x
def call(self, inputs, training=False): def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False):
w, r, attn_mask, mems, head_mask, output_attentions = inputs
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1] qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
if mems is not None: if mems is not None:
...@@ -249,7 +244,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): ...@@ -249,7 +244,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
# residual connection + layer normalization # residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)] outputs = [self.layer_norm(w + attn_out)]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs.append(attn_prob) outputs.append(attn_prob)
return outputs return outputs
...@@ -272,6 +267,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -272,6 +267,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
r_r_bias=None, r_r_bias=None,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
init_std=0.02, init_std=0.02,
output_attentions=False,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -290,6 +286,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -290,6 +286,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
r_r_bias=r_r_bias, r_r_bias=r_r_bias,
init_std=init_std, init_std=init_std,
layer_norm_epsilon=layer_norm_epsilon, layer_norm_epsilon=layer_norm_epsilon,
output_attentions=output_attentions,
name="dec_attn", name="dec_attn",
) )
self.pos_ff = TFPositionwiseFF( self.pos_ff = TFPositionwiseFF(
...@@ -302,11 +299,8 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): ...@@ -302,11 +299,8 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
name="pos_ff", name="pos_ff",
) )
def call(self, inputs, training=False): def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False):
dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions = inputs attn_outputs = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training)
attn_outputs = self.dec_attn(
[dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions], training=training
)
ff_output = self.pos_ff(attn_outputs[0], training=training) ff_output = self.pos_ff(attn_outputs[0], training=training)
outputs = [ff_output] + attn_outputs[1:] outputs = [ff_output] + attn_outputs[1:]
...@@ -443,6 +437,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -443,6 +437,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
r_r_bias=None if self.untie_r else self.r_r_bias, r_r_bias=None if self.untie_r else self.r_r_bias,
layer_norm_epsilon=config.layer_norm_epsilon, layer_norm_epsilon=config.layer_norm_epsilon,
init_std=config.init_std, init_std=config.init_std,
output_attentions=self.output_attentions,
name="layers_._{}".format(i), name="layers_._{}".format(i),
) )
) )
...@@ -625,10 +620,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -625,10 +620,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
hids.append(core_out) hids.append(core_out)
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
layer_outputs = layer( layer_outputs = layer(
[core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions], training=training, core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions, training=training,
) )
core_out = layer_outputs[0] core_out = layer_outputs[0]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attentions.append(layer_outputs[1]) attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
...@@ -639,12 +634,12 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -639,12 +634,12 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
# We transpose back here to shape [bsz, len, hidden_dim] # We transpose back here to shape [bsz, len, hidden_dim]
outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems] outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
if cast_bool_to_primitive(output_hidden_states): if output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim] # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out) hids.append(core_out)
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids) hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
outputs.append(hids) outputs.append(hids)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs.append(attentions) outputs.append(attentions)
...@@ -860,14 +855,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -860,14 +855,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
bsz, tgt_len = shape_list(inputs_embeds)[:2] bsz, tgt_len = shape_list(inputs_embeds)[:2]
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
[input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states], training=training input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states, training=training
) )
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:] outputs = transformer_outputs[1:]
softmax_output = self.crit([pred_hid, labels], training=training) softmax_output = self.crit(pred_hid, labels, training=training)
outputs = [softmax_output] + outputs outputs = [softmax_output] + outputs
return outputs # logits, new_mems, (all hidden states), (all attentions) return outputs # logits, new_mems, (all hidden states), (all attentions)
......
...@@ -114,8 +114,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer): ...@@ -114,8 +114,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
idx = tf.stack([r, target], 1) idx = tf.stack([r, target], 1)
return tf.gather_nd(logprob, idx) return tf.gather_nd(logprob, idx)
def call(self, inputs, return_mean=True, training=False): def call(self, hidden, target, return_mean=True, training=False):
hidden, target = inputs
head_logprob = 0 head_logprob = 0
if self.n_clusters == 0: if self.n_clusters == 0:
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0]) output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
......
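The Transformer-XL hunks replace the packed `inputs` list that used to be unpacked inside `call` with explicit named arguments, which keeps the call signature visible to Keras when the model is traced and saved. A before/after sketch of the idea on a hypothetical layer:

import tensorflow as tf

class PackedCallLayer(tf.keras.layers.Layer):
    # old style: the caller must build a list in exactly the right order
    def call(self, inputs, training=False):
        hidden, mask = inputs
        return hidden * mask

class ExplicitCallLayer(tf.keras.layers.Layer):
    # new style: named arguments, optional ones default to None
    def call(self, hidden, mask=None, training=False):
        if mask is not None:
            hidden = hidden * mask
        return hidden

layer = ExplicitCallLayer()
out = layer(tf.ones((2, 4)), mask=tf.ones((2, 4)))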
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import functools import functools
import logging import logging
import os import os
import warnings
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import h5py import h5py
...@@ -173,7 +174,11 @@ class TFTokenClassificationLoss: ...@@ -173,7 +174,11 @@ class TFTokenClassificationLoss:
) )
# make sure only labels that are not equal to -100 # make sure only labels that are not equal to -100
# are taken into account as loss # are taken into account as loss
active_loss = tf.reshape(labels, (-1,)) != -100 if tf.math.reduce_any(labels == -1).numpy() is True:
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
active_loss = tf.reshape(labels, (-1,)) != -1
else:
active_loss = tf.reshape(labels, (-1,)) != -100
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
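The token-classification loss keeps only positions whose label differs from the ignore index, warning when the older -1 convention is used instead of -100. A standalone sketch of the same masking, independent of the Hugging Face loss mixin:

import tensorflow as tf

def masked_token_loss(labels, logits, ignore_index=-100):
    # sparse cross-entropy over tokens whose label is not the ignore index
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    active_loss = tf.reshape(labels, (-1,)) != ignore_index
    reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, tf.shape(logits)[-1])), active_loss)
    reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
    return loss_fn(reduced_labels, reduced_logits)

labels = tf.constant([[1, -100, 2]])
logits = tf.random.normal((1, 3, 5))
per_token_loss = masked_token_loss(labels, logits)  # shape (2,), one value per kept token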
...@@ -233,7 +238,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -233,7 +238,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
@property @property
def dummy_inputs(self) -> Dict[str, tf.Tensor]: def dummy_inputs(self) -> Dict[str, tf.Tensor]:
""" """
:obj:`Dict[str, tf.Tensor]`: Dummy inputs to build the network. Dummy inputs to build the network.
Returns:
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
""" """
return {"input_ids": tf.constant(DUMMY_INPUTS)} return {"input_ids": tf.constant(DUMMY_INPUTS)}
...@@ -774,14 +782,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): ...@@ -774,14 +782,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
return tf.gather(self.weight, input_ids) return tf.gather(self.weight, input_ids)
def _linear(self, inputs): def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
""" """
first_dims = shape_list(inputs)[:-1] Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
"""
first_dims = shape_list(inputs)[:-1]
x = tf.reshape(inputs, [-1, self.hidden_size]) x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.weight, transpose_b=True) logits = tf.matmul(x, self.weight, transpose_b=True)
...@@ -789,7 +799,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): ...@@ -789,7 +799,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
class TFSequenceSummary(tf.keras.layers.Layer): class TFSequenceSummary(tf.keras.layers.Layer):
r""" """
Compute a single vector summary of a sequence hidden states. Compute a single vector summary of a sequence hidden states.
Args: Args:
...@@ -852,26 +862,9 @@ class TFSequenceSummary(tf.keras.layers.Layer): ...@@ -852,26 +862,9 @@ class TFSequenceSummary(tf.keras.layers.Layer):
if self.has_last_dropout: if self.has_last_dropout:
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout) self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
def call(self, inputs, training=False) -> tf.Tensor: def call(self, inputs, cls_index=None, training=False):
"""
Compute a single vector summary of a sequence hidden states.
Args:
inputs (:obj:`Union[tf.Tensor, Tuple[tf.Tensor], List[tf.Tensor], Dict[str, tf.Tensor]]`):
One or two tensors representing:
- **hidden_states** (:obj:`tf.Tensor` of shape :obj:`[batch_size, seq_len, hidden_size]`) -- The hidden
states of the last layer.
- **cls_index** :obj:`tf.Tensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are
optional leading dimensions of :obj:`hidden_states`. Used if :obj:`summary_type == "cls_index"` and
takes the last token of the sequence as classification token.
Returns:
:obj:`tf.Tensor`: The summary of the sequence hidden states.
"""
if not isinstance(inputs, (dict, tuple, list)): if not isinstance(inputs, (dict, tuple, list)):
hidden_states = inputs hidden_states = inputs
cls_index = None
elif isinstance(inputs, (tuple, list)): elif isinstance(inputs, (tuple, list)):
hidden_states = inputs[0] hidden_states = inputs[0]
cls_index = inputs[1] if len(inputs) > 1 else None cls_index = inputs[1] if len(inputs) > 1 else None
......
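TFSequenceSummary's `call` now takes `cls_index` as an explicit keyword argument instead of digging it out of a list or dict. For illustration only, a sketch of a cls_index-style summary (not the library's exact summary code), assuming `hidden_states` is `[batch, seq_len, hidden]` and `cls_index` is `[batch]`:

import tensorflow as tf

def cls_index_summary(hidden_states, cls_index=None):
    if cls_index is None:
        return hidden_states[:, -1]  # default to the last token of each sequence
    return tf.gather(hidden_states, cls_index, batch_dims=1, axis=1)

hidden_states = tf.random.normal((2, 7, 16))
summary = cls_index_summary(hidden_states, cls_index=tf.constant([6, 3]))  # shape (2, 16)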
...@@ -39,7 +39,6 @@ from .modeling_tf_utils import ( ...@@ -39,7 +39,6 @@ from .modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -123,6 +122,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -123,6 +122,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
self.layer_id = next(TFMultiHeadAttention.NEW_ID) self.layer_id = next(TFMultiHeadAttention.NEW_ID)
self.dim = dim self.dim = dim
self.n_heads = n_heads self.n_heads = n_heads
self.output_attentions = config.output_attentions
assert self.dim % self.n_heads == 0 assert self.dim % self.n_heads == 0
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin") self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
...@@ -135,11 +135,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -135,11 +135,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def prune_heads(self, heads): def prune_heads(self, heads):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, training=False): def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
""" """
Self-attention (if kv is None) or attention over source sentence (provided by kv). Self-attention (if kv is None) or attention over source sentence (provided by kv).
""" """
input, mask, kv, cache, head_mask, output_attentions = inputs
# Input is (bs, qlen, dim) # Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen) # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input) bs, qlen, dim = shape_list(input)
...@@ -196,7 +195,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -196,7 +195,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
context = unshape(context) # (bs, qlen, dim) context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),) outputs = (self.out_lin(context),)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (weights,) outputs = outputs + (weights,)
return outputs return outputs
...@@ -445,6 +444,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -445,6 +444,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
inputs_embeds = self.embeddings(input_ids) inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids) tensor = inputs_embeds + self.position_embeddings(position_ids)
if langs is not None and self.use_lang_emb and self.n_langs > 1: if langs is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None: if token_type_ids is not None:
...@@ -457,15 +457,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -457,15 +457,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = () hidden_states = ()
attentions = () attentions = ()
for i in range(self.n_layers): for i in range(self.n_layers):
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# self attention # self attention
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=training)
tensor = tensor + attn tensor = tensor + attn
...@@ -484,7 +484,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -484,7 +484,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state # Add last hidden state
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# update cache length # update cache length
...@@ -495,9 +495,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -495,9 +495,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
outputs = (tensor,) outputs = (tensor,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions) return outputs # outputs, (hidden_states), (attentions)
...@@ -930,7 +930,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -930,7 +930,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1])) tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None if inputs_embeds is not None
else None else None
) )
...@@ -943,7 +943,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -943,7 +943,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
) )
lengths = None lengths = None
flat_inputs = [ transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_langs, flat_langs,
...@@ -955,14 +955,12 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -955,14 +955,12 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
transformer_outputs = self.transformer(flat_inputs, training=training)
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
if labels is not None: if labels is not None:
......
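In the XLM multiple-choice head, the `inputs_embeds` reshape now goes through `shape_list` rather than indexing `.shape` directly, so the hidden dimension is still available when shapes are only known at run time. For illustration, a sketch of that helper and the flattening it enables (the real `shape_list` lives in modeling_tf_utils):

import tensorflow as tf

def shape_list(x):
    # static dimensions where known, dynamic tf.shape(x) entries otherwise
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

def flatten_choices(inputs_embeds, seq_length):
    # (batch, num_choices, seq_len, hidden) -> (batch * num_choices, seq_len, hidden)
    return tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))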
...@@ -38,7 +38,6 @@ from .modeling_tf_utils import ( ...@@ -38,7 +38,6 @@ from .modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
cast_bool_to_primitive,
get_initializer, get_initializer,
keras_serializable, keras_serializable,
shape_list, shape_list,
...@@ -92,6 +91,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -92,6 +91,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
self.d_model = config.d_model self.d_model = config.d_model
self.scale = 1 / (config.d_head ** 0.5) self.scale = 1 / (config.d_head ** 0.5)
self.initializer_range = config.initializer_range self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
...@@ -142,11 +142,10 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -142,11 +142,10 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return x return x
def rel_attn_core(self, inputs, training=False): def rel_attn_core(
self, q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions, training=False
):
"""Core relative positional attention operations.""" """Core relative positional attention operations."""
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
# content based attention score # content based attention score
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h) ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
...@@ -182,16 +181,14 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -182,16 +181,14 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# attention output # attention output
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h) attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
return attn_vec, attn_prob return attn_vec, attn_prob
return attn_vec return attn_vec
def post_attention(self, inputs, residual=True, training=False): def post_attention(self, h, attn_vec, residual=True, training=False):
"""Post-attention processing.""" """Post-attention processing."""
# post-attention projection (back to `d_model`) # post-attention projection (back to `d_model`)
h, attn_vec = inputs
attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o) attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
attn_out = self.dropout(attn_out, training=training) attn_out = self.dropout(attn_out, training=training)
...@@ -202,9 +199,20 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -202,9 +199,20 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return output return output
def call(self, inputs, training=False): def call(
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs self,
h,
g,
attn_mask_h,
attn_mask_g,
r,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=False,
):
if g is not None: if g is not None:
# Two-stream attention with relative positional encoding. # Two-stream attention with relative positional encoding.
# content based attention score # content based attention score
...@@ -228,15 +236,22 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -228,15 +236,22 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops # core attention ops
attn_vec_h = self.rel_attn_core( attn_vec_h = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions], q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_h,
head_mask,
output_attentions,
training=training, training=training,
) )
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attn_vec_h, attn_prob_h = attn_vec_h attn_vec_h, attn_prob_h = attn_vec_h
# post processing # post processing
output_h = self.post_attention([h, attn_vec_h], training=training) output_h = self.post_attention(h, attn_vec_h, training=training)
# g-stream # g-stream
# query-stream query head # query-stream query head
...@@ -246,27 +261,41 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -246,27 +261,41 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if target_mapping is not None: if target_mapping is not None:
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping) q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core( attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions], q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_g,
head_mask,
output_attentions,
training=training, training=training,
) )
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping) attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else: else:
attn_vec_g = self.rel_attn_core( attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions], q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_g,
head_mask,
output_attentions,
training=training, training=training,
) )
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g attn_vec_g, attn_prob_g = attn_vec_g
# post processing # post processing
output_g = self.post_attention([g, attn_vec_g], training=training) output_g = self.post_attention(g, attn_vec_g, training=training)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attn_prob = attn_prob_h, attn_prob_g attn_prob = attn_prob_h, attn_prob_g
else: else:
...@@ -286,19 +315,26 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer): ...@@ -286,19 +315,26 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# core attention ops # core attention ops
attn_vec = self.rel_attn_core( attn_vec = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions], q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_h,
head_mask,
output_attentions,
training=training, training=training,
) )
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attn_vec, attn_prob = attn_vec attn_vec, attn_prob = attn_vec
# post processing # post processing
output_h = self.post_attention([h, attn_vec], training=training) output_h = self.post_attention(h, attn_vec, training=training)
output_g = None output_g = None
outputs = (output_h, output_g) outputs = (output_h, output_g)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
outputs = outputs + (attn_prob,) outputs = outputs + (attn_prob,)
return outputs return outputs
...@@ -337,8 +373,33 @@ class TFXLNetLayer(tf.keras.layers.Layer): ...@@ -337,8 +373,33 @@ class TFXLNetLayer(tf.keras.layers.Layer):
self.ff = TFXLNetFeedForward(config, name="ff") self.ff = TFXLNetFeedForward(config, name="ff")
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def call(self, inputs, training=False): def call(
outputs = self.rel_attn(inputs, training=training) self,
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=False,
):
outputs = self.rel_attn(
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=training,
)
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if output_g is not None: if output_g is not None:
...@@ -686,32 +747,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -686,32 +747,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = [] hidden_states = []
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache is True: if self.mem_len is not None and self.mem_len > 0 and use_cache:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module( outputs = layer_module(
[ output_h,
output_h, output_g,
output_g, non_tgt_mask,
non_tgt_mask, attn_mask,
attn_mask, pos_emb,
pos_emb, seg_mat,
seg_mat, mems[i],
mems[i], target_mapping,
target_mapping, head_mask[i],
head_mask[i], output_attentions,
output_attentions,
],
training=training, training=training,
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attentions.append(outputs[2]) attentions.append(outputs[2])
# Add last hidden state # Add last hidden state
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training) output = self.dropout(output_g if output_g is not None else output_h, training=training)
...@@ -719,16 +778,16 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -719,16 +778,16 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (tf.transpose(output, perm=(1, 0, 2)),) outputs = (tf.transpose(output, perm=(1, 0, 2)),)
if self.mem_len is not None and self.mem_len > 0 and use_cache is True: if self.mem_len is not None and self.mem_len > 0 and use_cache:
outputs = outputs + (new_mems,) outputs = outputs + (new_mems,)
if cast_bool_to_primitive(output_hidden_states) is True: if output_hidden_states:
if output_g is not None: if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
else: else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if cast_bool_to_primitive(output_attentions) is True: if output_attentions:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
...@@ -1240,8 +1299,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1240,8 +1299,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
if inputs_embeds is not None if inputs_embeds is not None
else None else None
) )
transformer_outputs = self.transformer(
flat_inputs = [
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
mems, mems,
...@@ -1254,14 +1312,12 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1254,14 +1312,12 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
use_cache, use_cache,
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
] training=training,
)
transformer_outputs = self.transformer(flat_inputs, training=training)
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
if labels is not None: if labels is not None:
......
...@@ -4,7 +4,6 @@ import datetime ...@@ -4,7 +4,6 @@ import datetime
import logging import logging
import math import math
import os import os
import sys
import warnings import warnings
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, Optional, Tuple
...@@ -25,15 +24,6 @@ if is_wandb_available(): ...@@ -25,15 +24,6 @@ if is_wandb_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if parse(tf.__version__).release < (2, 2, 0):
logger.info(
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
tf.__version__
)
)
sys.exit(1)
class TFTrainer: class TFTrainer:
""" """
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow, TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
...@@ -77,6 +67,11 @@ class TFTrainer: ...@@ -77,6 +67,11 @@ class TFTrainer:
None, None,
), ),
): ):
assert parse(tf.__version__).release >= (2, 2, 0), (
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
% tf.__version__
)
self.model = model self.model = model
self.args = args self.args = args
self.train_dataset = train_dataset self.train_dataset = train_dataset
......
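The trainer now checks the TensorFlow version when it is constructed instead of exiting at import time. A minimal standalone version of that guard, assuming `packaging` is installed:

import tensorflow as tf
from packaging.version import parse

def require_tf_2_2():
    # .release is a tuple such as (2, 2, 0), so it compares cleanly against a minimum
    assert parse(tf.__version__).release >= (2, 2, 0), (
        "You need to run the TensorFlow trainer with at least the version 2.2.0, "
        "your version is %r" % tf.__version__
    )

require_tf_2_2()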
...@@ -23,7 +23,7 @@ import unittest ...@@ -23,7 +23,7 @@ import unittest
from importlib import import_module from importlib import import_module
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
if is_tf_available(): if is_tf_available():
...@@ -130,6 +130,61 @@ class TFModelTesterMixin: ...@@ -130,6 +130,61 @@ class TFModelTesterMixin:
self.assert_outputs_same(after_outputs, outputs) self.assert_outputs_same(after_outputs, outputs)
@slow
def test_saved_model_with_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
for model_class in self.all_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(inputs_dict))
model._saved_model_inputs_spec = None
model._set_save_spec(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
tf.saved_model.save(model, tmpdirname)
model = tf.keras.models.load_model(tmpdirname)
outputs = model(inputs_dict)
hidden_states = [t.numpy() for t in outputs[-1]]
self.assertEqual(len(outputs), num_out)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
)
@slow
def test_saved_model_with_attentions_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True
encoder_seq_length = (
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length
)
encoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
)
for model_class in self.all_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(inputs_dict))
model._saved_model_inputs_spec = None
model._set_save_spec(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
tf.saved_model.save(model, tmpdirname)
model = tf.keras.models.load_model(tmpdirname)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(len(outputs), num_out)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_keras_save_load(self): def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
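The two new tests drive the same export path: pin the save spec to a concrete input dict, write a SavedModel, reload it with Keras, and check that the extra hidden-state or attention outputs survive. A condensed sketch of that round trip, mirroring the test above for any TF model from the library and an `inputs` dict of tensors:

import tempfile
import tensorflow as tf

def saved_model_round_trip(model, inputs):
    model._saved_model_inputs_spec = None
    model._set_save_spec(inputs)  # pin the serving signature to these inputs
    with tempfile.TemporaryDirectory() as tmpdirname:
        tf.saved_model.save(model, tmpdirname)
        reloaded = tf.keras.models.load_model(tmpdirname)
        return reloaded(inputs)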
...@@ -342,11 +342,17 @@ class TFXLNetModelTester: ...@@ -342,11 +342,17 @@ class TFXLNetModelTester:
"attention_mask": multiple_choice_input_mask, "attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids, "token_type_ids": multiple_choice_token_type_ids,
} }
(logits,) = model(inputs) (logits, mems_1) = model(inputs)
result = { result = {
"mems_1": [mem.numpy() for mem in mems_1],
"logits": logits.numpy(), "logits": logits.numpy(),
} }
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
......