Unverified commit d161ed16, authored by Julien Plu, committed by GitHub

Update the TF models to remove their interdependencies (#7238)

* Refactor the models to remove their interdependencies

* Fix Flaubert model

* Fix Flaubert

* Fix XLM

* Fix Albert

* Fix Roberta

* Fix Albert

* Fix Flaubert

* Apply style + remove unused imports

* Fix Distilbert

* remove unused import

* fix Distilbert

* Fix Flaubert

* Apply style

* Fix Flaubert

* Add the copy comments for the check_copies script

* Fix MobileBert model name

* Address Morgan's comments

* Fix typo

* Oops typo
parent 0cffa424
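In short, the refactor trades cross-model inheritance (for example, `TFAlbertAttention` subclassing `TFBertSelfAttention`, visible in the first hunk below) for self-contained copies of the shared layers in each model file. A minimal sketch of the before/after pattern, using illustrative class names rather than the actual library code:

```python
import tensorflow as tf


class TFBertStyleIntermediate(tf.keras.layers.Layer):
    """Stand-in for a layer defined in one model file (e.g. BERT)."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units, name="dense")

    def call(self, hidden_states):
        return self.dense(hidden_states)


# Before the refactor: a second model reused the layer by subclassing it and
# overriding the parts that differed, which couples the two model files.
class TFOtherModelIntermediateBefore(TFBertStyleIntermediate):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.dense = tf.keras.layers.Dense(units, name="dense")  # e.g. drop a custom initializer


# After the refactor: the second model carries its own copy of the layer, so a
# change to the BERT file can no longer silently alter the other model.
class TFOtherModelIntermediateAfter(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units, name="dense")

    def call(self, hidden_states):
        return self.dense(hidden_states)
```

The duplication is intentional; the copy comments added for the check_copies script (see the commit message) are what keep the duplicated blocks from drifting apart.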
@@ -31,7 +31,6 @@ from .file_utils import (
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
-from .modeling_tf_bert import TFBertSelfAttention
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,

@@ -181,82 +180,6 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
        return tf.reshape(logits, [batch_size, length, self.config.vocab_size])

-class TFAlbertSelfAttention(tf.keras.layers.Layer):
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
-        if config.hidden_size % config.num_attention_heads != 0:
-            raise ValueError(
-                "The hidden size (%d) is not a multiple of the number of attention "
-                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
-            )
-        self.num_attention_heads = config.num_attention_heads
-        assert (
-            config.hidden_size % config.num_attention_heads == 0
-        ), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
-        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.output_attentions = config.output_attentions
-        self.query = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
-        )
-        self.key = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
-        )
-        self.value = tf.keras.layers.Dense(
-            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
-        )
-        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
-
-    def transpose_for_scores(self, x, batch_size):
-        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
-        return tf.transpose(x, perm=[0, 2, 1, 3])
-
-    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
-        batch_size = shape_list(hidden_states)[0]
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        # (batch size, num_heads, seq_len_q, seq_len_k)
-        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
-        # scale attention_scores
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
-        attention_scores = attention_scores / tf.math.sqrt(dk)
-
-        if attention_mask is not None:
-            # Apply the attention mask (precomputed for all layers in TFAlbertModel call() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs, training=training)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = tf.matmul(attention_probs, value_layer)
-        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
-        context_layer = tf.reshape(
-            context_layer, (batch_size, -1, self.all_head_size)
-        )  # (batch_size, seq_len_q, all_head_size)
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
-        return outputs

class TFAlbertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

@@ -273,7 +196,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
        return hidden_states

-class TFAlbertAttention(TFBertSelfAttention):
+class TFAlbertAttention(tf.keras.layers.Layer):
    """ Contains the complete attention sublayer, including both dropouts and layer norm. """

    def __init__(self, config, **kwargs):

@@ -281,6 +204,19 @@ class TFAlbertAttention(TFBertSelfAttention):
        self.hidden_size = config.hidden_size
        self.output_attentions = config.output_attentions
+        self.num_attention_heads = config.num_attention_heads
+        assert config.hidden_size % config.num_attention_heads == 0
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.query = tf.keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        )
+        self.key = tf.keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        )
+        self.value = tf.keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        )
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

@@ -290,6 +226,11 @@ class TFAlbertAttention(TFBertSelfAttention):
        self.attention_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+
+    def transpose_for_scores(self, x, batch_size):
+        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+        return tf.transpose(x, perm=[0, 2, 1, 3])

    def prune_heads(self, heads):
        raise NotImplementedError

@@ -342,6 +283,7 @@ class TFAlbertAttention(TFBertSelfAttention):
        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]
        return outputs
......
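For reference, the attention code that is now inlined into `TFAlbertAttention` (the query/key/value projections and `transpose_for_scores`) implements the standard multi-head reshape followed by scaled dot-product attention. A self-contained sketch of that computation (toy shapes, no dropout, no head mask; not the library code):

```python
import tensorflow as tf


def split_heads(x, batch_size, num_heads, head_size):
    # (batch, seq, num_heads * head_size) -> (batch, num_heads, seq, head_size)
    x = tf.reshape(x, (batch_size, -1, num_heads, head_size))
    return tf.transpose(x, perm=[0, 2, 1, 3])


def scaled_dot_product_attention(q, k, v, additive_mask=None):
    # q, k, v: (batch, num_heads, seq, head_size)
    scores = tf.matmul(q, k, transpose_b=True)  # (batch, heads, seq_q, seq_k)
    dk = tf.cast(tf.shape(k)[-1], scores.dtype)
    scores = scores / tf.math.sqrt(dk)          # scale by sqrt(head_size)
    if additive_mask is not None:
        scores = scores + additive_mask         # large negative values hide padding
    probs = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(probs, v)                  # (batch, heads, seq_q, head_size)


# Toy shapes: batch=2, seq=5, 4 heads of size 8 (hidden size 32).
batch, seq, heads, head_size = 2, 5, 4, 8
hidden = tf.random.normal((batch, seq, heads * head_size))
q = split_heads(hidden, batch, heads, head_size)
k = split_heads(hidden, batch, heads, head_size)
v = split_heads(hidden, batch, heads, head_size)
context = scaled_dot_product_attention(q, k, v)
print(context.shape)  # (2, 4, 5, 8)
```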
@@ -93,6 +93,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range

@@ -124,6 +125,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
            shape=[self.vocab_size, self.hidden_size],
            initializer=get_initializer(self.initializer_range),
        )
        super().build(input_shape)

    def call(

@@ -273,6 +275,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

@@ -290,6 +293,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
class TFBertAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

@@ -309,6 +313,7 @@ class TFBertAttention(tf.keras.layers.Layer):
class TFBertIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

@@ -328,6 +333,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
class TFBertOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

@@ -345,6 +351,7 @@ class TFBertOutput(tf.keras.layers.Layer):
class TFBertLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFBertAttention(config, name="attention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

@@ -364,6 +371,7 @@ class TFBertLayer(tf.keras.layers.Layer):
class TFBertEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(

@@ -397,6 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

@@ -405,6 +414,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
class TFBertPooler(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),

@@ -424,6 +434,7 @@ class TFBertPooler(tf.keras.layers.Layer):
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

@@ -446,6 +457,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

@@ -455,6 +467,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):

@@ -468,6 +481,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output):

@@ -479,6 +493,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
class TFBertNSPHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.seq_relationship = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
        )

@@ -495,6 +510,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions

@@ -571,6 +587,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

@@ -588,7 +605,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

@@ -767,6 +783,7 @@ BERT_INPUTS_DOCSTRING = r"""
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@@ -778,6 +795,7 @@ class TFBertModel(TFBertPreTrainedModel):
    )
    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        return outputs

@@ -818,7 +836,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

@@ -880,6 +897,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
            in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:

@@ -902,7 +920,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=training)
        loss = None if labels is None else self.compute_loss(labels, prediction_scores)
        if not return_dict:

@@ -956,6 +973,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
            Indices should be in ``[0, ..., config.vocab_size - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:

@@ -978,8 +996,8 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
        sequence_output = outputs[0]
        logits = self.mlm(sequence_output, training=training)
        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]

@@ -1033,7 +1051,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)
        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

@@ -1055,8 +1072,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
-        self.num_labels = config.num_labels
+        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(

@@ -1092,6 +1109,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:

@@ -1113,10 +1131,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        loss = None if labels is None else self.compute_loss(labels, logits)
        if not return_dict:

@@ -1208,6 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if input_ids is not None:

@@ -1242,7 +1259,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))
        loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
        if not return_dict:

@@ -1265,8 +1281,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
-        self.num_labels = config.num_labels
+        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(

@@ -1300,6 +1316,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:

@@ -1319,12 +1336,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)
        loss = None if labels is None else self.compute_loss(labels, logits)
        if not return_dict:

@@ -1347,8 +1361,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
-        self.num_labels = config.num_labels
+        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"

@@ -1387,6 +1401,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions

@@ -1408,15 +1423,13 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)
        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
......
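The `extended_attention_mask` lines kept as context in `TFBertMainLayer` turn a 1/0 padding mask into an additive bias that the self-attention layers add to their raw scores before the softmax. A standalone sketch of that transformation, assuming a `(batch, seq)` padding mask (not the library code):

```python
import tensorflow as tf

# attention_mask: 1 for real tokens, 0 for padding, shape (batch, seq)
attention_mask = tf.constant([[1, 1, 1, 0, 0]], dtype=tf.int32)

# Make it broadcastable to (batch, num_heads, seq_q, seq_k)
extended = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended = tf.cast(extended, tf.float32)

# 0.0 where attention is allowed, -10000.0 where it is not; added to the raw
# scores before the softmax, this drives masked positions to ~0 probability.
extended = (1.0 - extended) * -10000.0

scores = tf.zeros((1, 1, 5, 5))            # dummy raw attention scores
probs = tf.nn.softmax(scores + extended, axis=-1)
print(probs.numpy().round(2))              # the last two columns are ~0
```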
@@ -16,8 +16,6 @@
"""

-import math
import tensorflow as tf

from .activations_tf import get_tf_activation

@@ -217,9 +215,8 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
        k_length = shape_list(key)[1]
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()
-        dim_per_head = self.dim // self.n_heads
+        dim_per_head = tf.math.divide(self.dim, self.n_heads)
+        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
        mask_reshape = [bs, 1, 1, k_length]

        def shape(x):

@@ -233,17 +230,16 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
-        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
+        q = tf.cast(q, dtype=tf.float32)
+        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
+        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)
        mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
        # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, q_length, k_length)
-        scores_dtype = scores.dtype
-        # calculate `scores` in `tf.float32` to avoid numeric overflow
-        scores = tf.cast(scores, dtype=tf.float32) - 1e30 * (1.0 - tf.cast(mask, dtype=tf.float32))
-        weights = tf.cast(tf.nn.softmax(scores, axis=-1), dtype=scores_dtype)  # (bs, n_heads, qlen, klen)
+        mask = tf.cast(mask, dtype=scores.dtype)
+        scores = scores - 1e30 * (1.0 - mask)
+        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
......
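The DistilBERT change swaps the Python-level `math.sqrt` scaling for TensorFlow ops and casts the mask to the score dtype, keeping the whole computation inside the graph. Dividing by `sqrt(d)` and multiplying by `rsqrt(d)` are equivalent up to floating-point rounding; a quick standalone check (not the library code):

```python
import math

import tensorflow as tf

dim_per_head = 64
q = tf.random.normal((2, 4, 5, dim_per_head))

scaled_div = q / math.sqrt(dim_per_head)
scaled_rsqrt = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, tf.float32)))

# The two scalings agree to within float32 rounding error.
print(tf.reduce_max(tf.abs(scaled_div - scaled_rsqrt)).numpy())
```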
@@ -31,7 +31,6 @@ from .file_utils import (
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
-from .modeling_tf_bert import TFBertIntermediate
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,

@@ -63,11 +62,29 @@ _CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"

TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "mobilebert-uncased",
+    "google/mobilebert-uncased",
    # See all MobileBERT models at https://huggingface.co/models?filter=mobilebert
]

+class TFMobileBertIntermediate(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def call(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states

class TFLayerNorm(tf.keras.layers.LayerNormalization):
    def __init__(self, feat_size, *args, **kwargs):
        super().__init__(*args, **kwargs)

@@ -353,12 +370,6 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
        return outputs

-class TFMobileBertIntermediate(TFBertIntermediate):
-    def __init__(self, config, **kwargs):
-        super().__init__(config, **kwargs)
-        self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")

class TFOutputBottleneck(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
......
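The MobileBERT hunk replaces the subclass of `TFBertIntermediate` with a local copy of the layer. In this repository such copies are normally tagged with a "Copied from" comment that the check_copies utility can verify; a hypothetical sketch of what that can look like (the marker path, wording and names below are illustrative, not taken from this diff):

```python
from types import SimpleNamespace

import tensorflow as tf


# Copied from transformers.modeling_tf_bert.TFBertIntermediate with Bert->MobileBert
# (illustrative marker; the exact format is defined by the repository's check_copies script)
class TFMobileBertIntermediateSketch(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")

    def call(self, hidden_states):
        return self.dense(hidden_states)


config = SimpleNamespace(intermediate_size=512)
layer = TFMobileBertIntermediateSketch(config)
print(layer(tf.zeros((1, 4, 128))).shape)  # (1, 4, 512)
```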
@@ -17,7 +17,6 @@
import itertools
-import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -114,13 +113,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
    return mask, attn_mask

-class TFMultiHeadAttention(tf.keras.layers.Layer):
+class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config, **kwargs):
        super().__init__(**kwargs)
-        self.layer_id = next(TFMultiHeadAttention.NEW_ID)
+        self.layer_id = next(TFXLMMultiHeadAttention.NEW_ID)
        self.dim = dim
        self.n_heads = n_heads
        self.output_attentions = config.output_attentions

@@ -143,13 +141,15 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = shape_list(input)
        if kv is None:
            klen = qlen if cache is None else cache["slen"] + qlen
        else:
            klen = shape_list(kv)[1]
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        n_heads = self.n_heads
-        dim_per_head = self.dim // n_heads
+        dim_per_head = tf.math.divide(self.dim, self.n_heads)
+        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
        mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)

        def shape(x):

@@ -161,6 +161,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)

@@ -177,14 +178,17 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
                    v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

-        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        q = tf.cast(q, dtype=tf.float32)
+        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))  # (bs, n_heads, qlen, dim_per_head)
+        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
        mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
        # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+        mask = tf.cast(mask, dtype=scores.dtype)
        scores = scores - 1e30 * (1.0 - mask)
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

@@ -194,16 +198,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)
        outputs = (self.out_lin(context),)
        if output_attentions:
            outputs = outputs + (weights,)
        return outputs

-class TFTransformerFFN(tf.keras.layers.Layer):
+class TFXLMTransformerFFN(tf.keras.layers.Layer):
    def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
        super().__init__(**kwargs)
        self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
        self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
        self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")

@@ -214,6 +220,7 @@ class TFTransformerFFN(tf.keras.layers.Layer):
        x = self.act(x)
        x = self.lin2(x)
        x = self.dropout(x, training=training)
        return x

@@ -223,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.return_dict = config.use_return_dict

@@ -230,8 +238,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # encoder / decoder, output layer
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        # self.with_output = with_output
        self.causal = config.causal

@@ -257,16 +267,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # embeddings
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            self.dim,
            embeddings_initializer=get_initializer(config.embed_init_std),
            name="position_embeddings",
        )
        if config.sinusoidal_embeddings:
            raise NotImplementedError
            # create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = tf.keras.layers.Embedding(
                self.n_langs,

@@ -274,6 +285,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
                embeddings_initializer=get_initializer(config.embed_init_std),
                name="lang_embeddings",
            )
        self.embeddings = TFSharedEmbeddings(
            self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
        )  # padding_idx=self.pad_index)

@@ -290,7 +302,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        for i in range(self.n_layers):
            self.attentions.append(
-                TFMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
+                TFXLMMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
            )
            self.layer_norm1.append(
                tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))

@@ -299,7 +311,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
            # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            self.ffns.append(
-                TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
+                TFXLMTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
            )
            self.layer_norm2.append(
                tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))

@@ -308,6 +320,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

@@ -398,7 +411,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # check inputs
        # assert shape_list(lengths)[0] == bs
-        tf.debugging.assert_equal(shape_list(lengths)[0], bs)
+        tf.debugging.assert_equal(
+            shape_list(lengths)[0], bs
+        ), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
        # assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)

@@ -416,13 +431,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
            position_ids = tf.expand_dims(tf.range(slen), axis=0)
        else:
            # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
-            tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
+            tf.debugging.assert_equal(
+                shape_list(position_ids), [bs, slen]
+            ), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
            # position_ids = position_ids.transpose(0, 1)

        # langs
        if langs is not None:
            # assert shape_list(langs) == [bs, slen]  # (slen, bs)
-            tf.debugging.assert_equal(shape_list(langs), [bs, slen])
+            tf.debugging.assert_equal(
+                shape_list(langs), [bs, slen]
+            ), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
            # langs = langs.transpose(0, 1)

        # Prepare head mask if needed

@@ -455,6 +474,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = self.dropout(tensor, training=training)
        tensor = tensor * mask[..., tf.newaxis]

@@ -462,6 +482,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # transformer layers
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None
        for i in range(self.n_layers):
            if output_hidden_states:
                hidden_states = hidden_states + (tensor,)

@@ -471,8 +492,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
                tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
            )
            attn = attn_outputs[0]
            if output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = self.dropout(attn, training=training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

@@ -502,6 +525,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        if not return_dict:
            return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
        return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)

@@ -691,9 +715,11 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        if config.asm is False:
            self.input_embeddings = input_embeddings
        else:

@@ -709,11 +735,13 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        # The output weights are the same as the input embeddings, but there is an output-only bias for each token.
        self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states
......
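The XLM hunks also rework the shape assertions so the expected and actual values are spelled out next to the check. For reference, `tf.debugging.assert_equal` accepts a `message` argument that is reported when the check fails, in both eager and graph mode; a minimal standalone example (not the library code):

```python
import tensorflow as tf

lengths = tf.constant([5, 3])
bs = 2

# Fails with the given message if the first dimension of `lengths` != bs.
tf.debugging.assert_equal(
    tf.shape(lengths)[0],
    bs,
    message="Expected the lengths tensor to have one entry per batch element",
)
print("shape check passed")
```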