Unverified commit d161ed16, authored by Julien Plu, committed by GitHub

Update the TF models to remove their interdependencies (#7238)

* Refactor the models to remove their interdependencies

* Fix Flaubert model

* Fix Flaubert

* Fix XLM

* Fix Albert

* Fix Roberta

* Fix Albert

* Fix Flaubert

* Apply style + remove unused imports

* Fix Distilbert

* Remove unused import

* Fix Distilbert

* Fix Flaubert

* Apply style

* Fix Flaubert

* Add the copy comments for the check_copies script

* Fix MobileBert model name

* Address Morgan's comments

* Fix typo

* Oops typo
parent 0cffa424
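The pattern applied throughout the diff below: instead of importing another model's layers (ALBERT and ELECTRA previously pulled attention and encoder classes out of modeling_tf_bert), each TF model file now defines its own copy of the shared block and tags it with a "# Copied from transformers.modeling_tf_bert..." comment so the check_copies script can keep the duplicates in sync with the BERT reference. A minimal sketch of the idea, assuming a toy DemoConfig; the class below is illustrative, not code from the library:

import tensorflow as tf


class DemoConfig:
    # Illustrative stand-in for a transformers model config.
    hidden_size = 64
    num_attention_heads = 4
    attention_probs_dropout_prob = 0.1


# Before: a cross-file import such as
#     from .modeling_tf_bert import TFBertSelfAttention
# coupled otherwise independent model files together.
#
# After: the block lives in the model's own file and carries a marker such as
#     # Copied from transformers.modeling_tf_bert.TFBertSelfAttention
# so an automated check can verify it still matches the reference implementation.
class DemoSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        assert config.hidden_size % config.num_attention_heads == 0
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # The projections that ALBERT and ELECTRA previously inherited from BERT.
        self.query = tf.keras.layers.Dense(self.all_head_size, name="query")
        self.key = tf.keras.layers.Dense(self.all_head_size, name="key")
        self.value = tf.keras.layers.Dense(self.all_head_size, name="value")
        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

The duplicated blocks are plain Keras layers, so runtime behavior does not change; only the import graph between model files does.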
modeling_tf_albert.py
@@ -31,7 +31,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_tf_bert import TFBertSelfAttention
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
@@ -181,82 +180,6 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
return tf.reshape(logits, [batch_size, length, self.config.vocab_size])
class TFAlbertSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.output_attentions = config.output_attentions
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
# scale attention_scores
dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
attention_scores = attention_scores / tf.math.sqrt(dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class TFAlbertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
@@ -273,7 +196,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
return hidden_states
class TFAlbertAttention(TFBertSelfAttention):
class TFAlbertAttention(tf.keras.layers.Layer):
""" Contains the complete attention sublayer, including both dropouts and layer norm. """
def __init__(self, config, **kwargs):
@@ -281,6 +204,19 @@ class TFAlbertAttention(TFBertSelfAttention):
self.hidden_size = config.hidden_size
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -290,6 +226,11 @@ class TFAlbertAttention(TFBertSelfAttention):
self.attention_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def prune_heads(self, heads):
raise NotImplementedError
@@ -342,6 +283,7 @@ class TFAlbertAttention(TFBertSelfAttention):
# add attentions if we output them
outputs = (attention_output,) + self_outputs[1:]
return outputs
...
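The transpose_for_scores helper that moves into TFAlbertAttention above is the standard multi-head reshape from (batch, seq, all_head_size) to (batch, num_heads, seq, head_size). A standalone sketch of the shape bookkeeping, with illustrative sizes:

import tensorflow as tf

batch_size, seq_len, num_heads, head_size = 2, 5, 4, 16
hidden = tf.random.normal((batch_size, seq_len, num_heads * head_size))

# (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
x = tf.reshape(hidden, (batch_size, -1, num_heads, head_size))
x = tf.transpose(x, perm=[0, 2, 1, 3])
print(x.shape)  # (2, 4, 5, 16)

Scores are then computed per head as a (seq, head_size) x (head_size, seq) matmul, which is why the head dimension has to be moved in front of the sequence dimension.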
modeling_tf_bert.py
@@ -93,6 +93,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
@@ -124,6 +125,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def call(
@@ -273,6 +275,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -290,6 +293,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
class TFBertAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFBertSelfAttention(config, name="self")
self.dense_output = TFBertSelfOutput(config, name="output")
@@ -309,6 +313,7 @@ class TFBertAttention(tf.keras.layers.Layer):
class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -328,6 +333,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -345,6 +351,7 @@ class TFBertOutput(tf.keras.layers.Layer):
class TFBertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.attention = TFBertAttention(config, name="attention")
self.intermediate = TFBertIntermediate(config, name="intermediate")
self.bert_output = TFBertOutput(config, name="output")
@@ -364,6 +371,7 @@ class TFBertLayer(tf.keras.layers.Layer):
class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(
@@ -397,6 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@@ -405,6 +414,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
@@ -424,6 +434,7 @@ class TFBertPooler(tf.keras.layers.Layer):
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -446,6 +457,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
class TFBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.transform = TFBertPredictionHeadTransform(config, name="transform")
@@ -455,6 +467,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
@@ -468,6 +481,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
class TFBertMLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
def call(self, sequence_output):
@@ -479,6 +493,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(
2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
)
@@ -495,6 +510,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
@@ -571,6 +587,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
@@ -588,7 +605,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
@@ -767,6 +783,7 @@ BERT_INPUTS_DOCSTRING = r"""
class TFBertModel(TFBertPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name="bert")
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -778,6 +795,7 @@ class TFBertModel(TFBertPreTrainedModel):
)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
return outputs
@@ -818,7 +836,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
seq_relationship_score = self.nsp(pooled_output)
@@ -880,6 +897,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -902,7 +920,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
if not return_dict:
@@ -956,6 +973,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
Indices should be in ``[0, ..., config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -978,8 +996,8 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
sequence_output = outputs[0]
logits = self.mlm(sequence_output, training=training)
loss = None
if labels is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
@@ -1033,7 +1051,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)
pooled_output = outputs[1]
seq_relationship_score = self.nsp(pooled_output)
@@ -1055,8 +1072,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert") self.bert = TFBertMainLayer(config, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense( self.classifier = tf.keras.layers.Dense(
...@@ -1092,6 +1109,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific ...@@ -1092,6 +1109,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9: if len(inputs) > 9:
...@@ -1113,10 +1131,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific ...@@ -1113,10 +1131,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
...@@ -1208,6 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1208,6 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
assert len(inputs) <= 10, "Too many inputs." assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
return_dict = return_dict if return_dict is not None else self.bert.return_dict return_dict = return_dict if return_dict is not None else self.bert.return_dict
if input_ids is not None: if input_ids is not None:
...@@ -1242,7 +1259,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1242,7 +1259,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
if not return_dict: if not return_dict:
...@@ -1265,8 +1281,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1265,8 +1281,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert") self.bert = TFBertMainLayer(config, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense( self.classifier = tf.keras.layers.Dense(
...@@ -1300,6 +1316,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1300,6 +1316,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
Indices should be in ``[0, ..., config.num_labels - 1]``. Indices should be in ``[0, ..., config.num_labels - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9: if len(inputs) > 9:
...@@ -1319,12 +1336,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1319,12 +1336,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
...@@ -1347,8 +1361,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ...@@ -1347,8 +1361,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert") self.bert = TFBertMainLayer(config, name="bert")
self.qa_outputs = tf.keras.layers.Dense( self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
...@@ -1387,6 +1401,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) ...@@ -1387,6 +1401,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
Position outside of the sequence are not taken into account for computing the loss. Position outside of the sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions end_positions = inputs[10] if len(inputs) > 10 else end_positions
...@@ -1408,15 +1423,13 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) ...@@ -1408,15 +1423,13 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1) start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1) start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions} labels = {"start_position": start_positions}
labels["end_position"] = end_positions labels["end_position"] = end_positions
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
""" """
import math
import tensorflow as tf import tensorflow as tf
from .activations_tf import get_tf_activation from .activations_tf import get_tf_activation
...@@ -217,9 +215,8 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -217,9 +215,8 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
k_length = shape_list(key)[1] k_length = shape_list(key)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
# assert key.size() == value.size() # assert key.size() == value.size()
dim_per_head = tf.math.divide(self.dim, self.n_heads)
dim_per_head = self.dim // self.n_heads dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = [bs, 1, 1, k_length] mask_reshape = [bs, 1, 1, k_length]
def shape(x): def shape(x):
...@@ -233,17 +230,16 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): ...@@ -233,17 +230,16 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
q = tf.cast(q, dtype=tf.float32)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length) scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length) # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length)
scores_dtype = scores.dtype mask = tf.cast(mask, dtype=scores.dtype)
# calculate `scores` in `tf.float32` to avoid numeric overflow scores = scores - 1e30 * (1.0 - mask)
scores = tf.cast(scores, dtype=tf.float32) - 1e30 * (1.0 - tf.cast(mask, dtype=tf.float32)) weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = tf.cast(tf.nn.softmax(scores, axis=-1), dtype=scores_dtype) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to # Mask heads if we want to
......
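The DistilBERT change above removes the import of math and the float32 upcast: q is scaled by rsqrt(dim_per_head) and the additive mask is cast to the dtype of the scores, so the whole scoring path stays in the computation dtype. A standalone sketch of the new formulation with illustrative shapes (not library code):

import tensorflow as tf

bs, n_heads, q_len, k_len, dim_per_head = 2, 4, 5, 5, 16
q = tf.random.normal((bs, n_heads, q_len, dim_per_head))
k = tf.random.normal((bs, n_heads, k_len, dim_per_head))
mask = tf.ones((bs, 1, 1, k_len))  # 1 = attend, 0 = masked out

# Scale q by 1/sqrt(dim_per_head) instead of casting q to float32 and dividing.
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_len, k_len)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)  # large negative value on masked positions
weights = tf.nn.softmax(scores, axis=-1)
print(weights.shape)  # (2, 4, 5, 5)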
modeling_tf_electra.py
@@ -13,7 +13,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_tf_bert import TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFMaskedLMOutput,
@@ -25,6 +24,7 @@ from .modeling_tf_outputs import (
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
@@ -53,15 +53,253 @@ TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
class TFElectraSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFElectraSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Electra
class TFElectraAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFElectraSelfAttention(config, name="self")
self.dense_output = TFElectraSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
self_outputs = self.self_attention(
input_tensor, attention_mask, head_mask, output_attentions, training=training
)
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFElectraIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFElectraOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Electra
class TFElectraLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.attention = TFElectraAttention(config, name="attention")
self.intermediate = TFElectraIntermediate(config, name="intermediate")
self.bert_output = TFElectraOutput(config, name="output")
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions, training=training
)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Electra
class TFElectraEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.layer = [TFElectraLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states,
attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFElectraPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
return pooled_output
class TFElectraEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.embedding_size = config.embedding_size
self.initializer_range = config.initializer_range
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.embedding_size,
@@ -90,11 +328,13 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
shape=[self.vocab_size, self.embedding_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
# Copied from transformers.modeling_tf_bert.TFBertEmbeddings.call
def call(
self,
input_ids,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
@@ -122,6 +362,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
else:
raise ValueError("mode {} is not valid.".format(mode))
# Copied from transformers.modeling_tf_bert.TFBertEmbeddings._embedding
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None)
@@ -132,19 +373,22 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
input_shape = shape_list(inputs_embeds)[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)
return embeddings
def _linear(self, inputs):
@@ -156,7 +400,6 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
"""
batch_size = shape_list(inputs)[0]
length = shape_list(inputs)[1]
x = tf.reshape(inputs, [-1, self.embedding_size])
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
@@ -194,11 +437,47 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
return hidden_states
class TFElectraPreTrainedModel(TFBertPreTrainedModel):
class TFElectraPreTrainedModel(TFPreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = ElectraConfig
base_model_prefix = "electra"
@keras_serializable
class TFElectraMainLayer(tf.keras.layers.Layer):
config_class = ElectraConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
if config.embedding_size != config.hidden_size:
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
self.encoder = TFElectraEncoder(config, name="encoder")
self.config = config
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
raise NotImplementedError
def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
@@ -215,7 +494,6 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
@@ -229,38 +507,6 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
return head_mask
@keras_serializable
class TFElectraMainLayer(TFElectraPreTrainedModel):
config_class = ElectraConfig
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
if config.embedding_size != config.hidden_size:
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
self.encoder = TFBertEncoder(config, name="encoder")
self.config = config
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
raise NotImplementedError
def call( def call(
self, self,
inputs, inputs,
...@@ -316,11 +562,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): ...@@ -316,11 +562,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.fill(input_shape, 1) attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0) token_type_ids = tf.fill(input_shape, 0)
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
head_mask = self.get_head_mask(head_mask) head_mask = self.get_head_mask(head_mask)
...@@ -462,6 +708,7 @@ ELECTRA_INPUTS_DOCSTRING = r""" ...@@ -462,6 +708,7 @@ ELECTRA_INPUTS_DOCSTRING = r"""
class TFElectraModel(TFElectraPreTrainedModel): class TFElectraModel(TFElectraPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.electra = TFElectraMainLayer(config, name="electra") self.electra = TFElectraMainLayer(config, name="electra")
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
...@@ -473,6 +720,7 @@ class TFElectraModel(TFElectraPreTrainedModel): ...@@ -473,6 +720,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
) )
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.electra(inputs, **kwargs) outputs = self.electra(inputs, **kwargs)
return outputs return outputs
...@@ -521,7 +769,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
>>> scores = outputs[0]
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
discriminator_hidden_states = self.electra(
input_ids,
attention_mask,
...@@ -550,16 +797,19 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
class TFElectraMaskedLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states, training=False):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
...@@ -577,10 +827,12 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
self.vocab_size = config.vocab_size
self.electra = TFElectraMainLayer(config, name="electra")
self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
if isinstance(config.hidden_act, str):
self.activation = get_tf_activation(config.hidden_act)
else:
self.activation = config.hidden_act
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
def get_output_embeddings(self):
...@@ -615,8 +867,10 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(input_ids, (tuple, list)):
labels = input_ids[9] if len(input_ids) > 9 else labels
if len(input_ids) > 9:
input_ids = input_ids[:9]
elif isinstance(input_ids, (dict, BatchEncoding)):
...@@ -637,11 +891,11 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
if not return_dict:
output = (prediction_scores,) + generator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
...@@ -657,6 +911,7 @@ class TFElectraClassificationHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
...@@ -717,8 +972,10 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(input_ids, (tuple, list)):
labels = input_ids[9] if len(input_ids) > 9 else labels
if len(input_ids) > 9:
input_ids = input_ids[:9]
elif isinstance(input_ids, (dict, BatchEncoding)):
...@@ -737,11 +994,11 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
training=training,
)
logits = self.classifier(outputs[0])
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
...@@ -831,6 +1088,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if input_ids is not None:
...@@ -864,11 +1122,11 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
...@@ -922,8 +1180,10 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
Indices should be in ``[0, ..., config.num_labels - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
...@@ -944,11 +1204,11 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
...@@ -967,8 +1227,8 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.electra = TFElectraMainLayer(config, name="electra")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
...@@ -1007,9 +1267,11 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
Positions outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
...@@ -1029,13 +1291,12 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
training=training,
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.qa_outputs(discriminator_sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
...@@ -1046,6 +1307,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
start_logits,
end_logits,
) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return TFQuestionAnsweringModelOutput(
...@@ -15,24 +15,23 @@
""" TF 2.0 Flaubert model.
"""
import random
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple
import tensorflow as tf
from transformers.activations_tf import get_tf_activation
from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .file_utils import ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_outputs import TFBaseModelOutput
from .modeling_tf_utils import keras_serializable, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
from .modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMMainLayer,
TFXLMModel,
TFXLMPredLayer,
TFXLMWithLMHeadModel,
get_masks,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
...@@ -40,6 +39,9 @@ from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "FlaubertConfig"
_TOKENIZER_FOR_DOC = "FlaubertTokenizer"
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Flaubert models at https://huggingface.co/models?filter=flaubert
]
...@@ -155,27 +157,258 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
"""
def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
"""
Generate hidden states mask, and optionally an attention mask.
"""
bs = shape_list(lengths)[0]
if padding_mask is not None:
mask = padding_mask
else:
# assert lengths.max().item() <= slen
alen = tf.range(slen)
mask = tf.math.less(alen, lengths[:, tf.newaxis])
# attention mask is the same as mask, or triangular inferior attention (causal)
if causal:
attn_mask = tf.less_equal(
tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
)
else:
attn_mask = mask
# sanity check
# assert shape_list(mask) == [bs, slen]
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
mask = tf.cast(mask, dtype=dtype)
attn_mask = tf.cast(attn_mask, dtype=dtype)
return mask, attn_mask
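# Worked example (illustrative only, assuming eager execution and the `tf`/`shape_list` imports
# above): for slen=4 and lengths=[2, 4], `mask` flags the non-padded positions of each sequence,
# and with causal=True `attn_mask` is the (bs, slen, slen) lower-triangular attention mask.
#   mask, attn_mask = get_masks(4, tf.constant([2, 4]), causal=True)
#   mask.numpy()           # [[1., 1., 0., 0.], [1., 1., 1., 1.]]
#   shape_list(attn_mask)  # [2, 4, 4]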
class TFFlaubertPreTrainedModel(TFPreTrainedModel):
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = FlaubertConfig
base_model_prefix = "transformer"
@property
def dummy_inputs(self):
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
if self.config.use_lang_emb and self.config.n_langs > 1:
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
else:
langs_list = None
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
@add_start_docstrings(
"The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertModel(TFXLMModel):
class TFFlaubertModel(TFFlaubertPreTrainedModel):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="jplu/tf-flaubert-small-cased",
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
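# Illustrative usage sketch, not part of this commit; it mirrors the checkpoint referenced in the
# code-sample docstring above.
#   from transformers import FlaubertTokenizer, TFFlaubertModel
#   tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-small-cased")
#   model = TFFlaubertModel.from_pretrained("jplu/tf-flaubert-small-cased")
#   inputs = tokenizer("Le camembert est délicieux !", return_tensors="tf")
#   last_hidden_state = model(inputs)[0]  # (batch_size, sequence_length, emb_dim)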
# Copied from transformers.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert
class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
NEW_ID = itertools.count()
def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs)
self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID)
self.dim = dim
self.n_heads = n_heads
self.output_attentions = config.output_attentions
assert self.dim % self.n_heads == 0
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
self.k_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin")
self.v_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin")
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.pruned_heads = set()
def prune_heads(self, heads):
raise NotImplementedError
def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input)
if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen
else:
klen = shape_list(kv)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
dim_per_head = tf.math.divide(self.dim, self.n_heads)
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
def shape(x):
""" projection """
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
def unshape(x):
""" compute context """
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
elif cache is None or self.layer_id not in cache:
k = v = kv
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
if cache is not None:
if self.layer_id in cache:
if kv is None:
k_, v_ = cache[self.layer_id]
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
q = tf.cast(q, dtype=tf.float32)
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if output_attentions:
outputs = outputs + (weights,)
return outputs
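# Shape walkthrough (illustrative only): with bs=2, qlen=5, n_heads=4 and dim_per_head=8, the
# shape() helper splits the hidden dimension into heads and unshape() merges it back.
#   x = tf.zeros((2, 5, 32))
#   split = tf.transpose(tf.reshape(x, (2, -1, 4, 8)), perm=(0, 2, 1, 3))     # (2, 4, 5, 8)
#   merged = tf.reshape(tf.transpose(split, perm=(0, 2, 1, 3)), (2, -1, 32))  # (2, 5, 32)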
# Copied from transformers.modeling_tf_xlm.TFXLMTransformerFFN
class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
super().__init__(**kwargs)
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
self.dropout = tf.keras.layers.Dropout(config.dropout)
def call(self, input, training=False):
x = self.lin1(input)
x = self.act(x)
x = self.lin2(x)
x = self.dropout(x, training=training)
return x
@keras_serializable
class TFFlaubertMainLayer(TFXLMMainLayer):
class TFFlaubertMainLayer(tf.keras.layers.Layer):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
super().__init__(**kwargs)
self.n_heads = config.n_heads
self.n_langs = config.n_langs
self.dim = config.emb_dim
self.hidden_dim = self.dim * 4
self.n_words = config.n_words
self.pad_index = config.pad_index
self.causal = config.causal
self.n_layers = config.n_layers
self.use_lang_emb = config.use_lang_emb
self.layerdrop = getattr(config, "layerdrop", 0.0)
self.pre_norm = getattr(config, "pre_norm", False)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="position_embeddings",
)
if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = tf.keras.layers.Embedding(
self.n_langs,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="lang_embeddings",
)
self.embeddings = TFSharedEmbeddings(
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
)
self.layer_norm_emb = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb")
self.attentions = []
self.layer_norm1 = []
self.ffns = []
self.layer_norm2 = []
for i in range(self.n_layers):
self.attentions.append(
TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
)
self.layer_norm1.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
)
# if self.is_decoder:
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(
TFFlaubertTransformerFFN(
self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i)
)
)
self.layer_norm2.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
)
def get_input_embeddings(self):
return self.embeddings
def call(
self,
...@@ -305,21 +538,26 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids)
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = tensor * mask[..., tf.newaxis]
# hidden_states and attentions cannot be None in graph mode.
hidden_states = ()
attentions = ()
# transformer layers
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
# LayerDrop
dropout_probability = random.uniform(0, 1)
dropout_probability = tf.random.uniform([1], 0, 1)
if training and (dropout_probability < self.layerdrop):
if training and tf.less(dropout_probability, self.layerdrop):
continue
if output_hidden_states:
...@@ -331,8 +569,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
...@@ -342,8 +582,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
...@@ -375,23 +617,129 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
# Set to None here if the output booleans are at False
hidden_states = hidden_states if output_hidden_states else None
attentions = attentions if output_attentions else None
if not return_dict:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
# Copied from transformers.modeling_tf_xlm.TFXLMPredLayer
class TFFlaubertPredLayer(tf.keras.layers.Layer):
"""
Prediction layer (cross_entropy or adaptive_softmax).
"""
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.asm = config.asm
self.n_words = config.n_words
self.pad_index = config.pad_index
if config.asm is False:
self.input_embeddings = input_embeddings
else:
raise NotImplementedError
# self.proj = nn.AdaptiveLogSoftmaxWithLoss(
# in_features=dim,
# n_classes=config.n_words,
# cutoffs=config.asm_cutoffs,
# div_value=config.asm_div_value,
# head_bias=True, # default is False
# )
def build(self, input_shape):
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@dataclass
class TFFlaubertWithLMHeadModelOutput(ModelOutput):
"""
Base class for :class:`~transformers.TFFlaubertWithLMHeadModel` outputs.
Args:
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@add_start_docstrings(
"""The Flaubert Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def prepare_inputs_for_generation(self, inputs, **kwargs):
mask_token_id = self.config.mask_token_id
lang_id = self.config.lang_id
effective_batch_size = inputs.shape[0]
mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
inputs = tf.concat([inputs, mask_token], axis=1)
if lang_id is not None:
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="jplu/tf-flaubert-small-cased",
output_type=TFFlaubertWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
transformer_outputs = self.transformer(inputs, **kwargs)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
if not return_dict:
return (outputs,) + transformer_outputs[1:]
return TFFlaubertWithLMHeadModelOutput(
logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
)
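# Illustrative usage sketch, not part of this commit: without return_dict the first element of the
# returned tuple holds the prediction scores over the vocabulary.
#   from transformers import FlaubertTokenizer, TFFlaubertWithLMHeadModel
#   tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-small-cased")
#   model = TFFlaubertWithLMHeadModel.from_pretrained("jplu/tf-flaubert-small-cased")
#   inputs = tokenizer("Le camembert est délicieux !", return_tensors="tf")
#   logits = model(inputs)[0]  # (batch_size, sequence_length, n_words)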
@add_start_docstrings(
...@@ -16,16 +16,16 @@
import tensorflow as tf
from transformers.activations_tf import get_tf_activation
from .configuration_longformer import LongformerConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
TFQuestionAnsweringModelOutput,
)
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFPreTrainedModel,
...@@ -84,18 +84,280 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
return attention_mask
# Copied from transformers.modeling_tf_roberta.TFRobertaLMHead
class TFLongformerLMHead(tf.keras.layers.Layer):
"""Roberta Head for masked language modeling."""
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.act = get_tf_activation("gelu")
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, features):
x = self.dense(features)
x = self.act(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x, mode="linear") + self.bias
return x
# Copied from transformers.modeling_tf_roberta.TFRobertaEmbeddings
class TFLongformerEmbeddings(tf.keras.layers.Layer):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.padding_idx = 1
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name="position_embeddings",
)
self.token_type_embeddings = tf.keras.layers.Embedding(
config.type_vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name="token_type_embeddings",
)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def build(self, input_shape):
"""Build shared word embedding layer """
with tf.name_scope("word_embeddings"):
# Create and initialize weights. The random normal initializer was chosen
# arbitrarily, and works well.
self.word_embeddings = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def create_position_ids_from_input_ids(self, x):
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param tf.Tensor x:
:return tf.Tensor:
"""
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx
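# Worked example (illustrative only): with padding_idx = 1, padded positions keep the padding
# index and real tokens are numbered from padding_idx + 1 onward.
#   input_ids:      [[5, 7, 9, 1, 1]]
#   mask:           [[1, 1, 1, 0, 0]]
#   cumsum * mask:  [[1, 2, 3, 0, 0]]
#   position_ids:   [[2, 3, 4, 1, 1]]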
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param tf.Tensor inputs_embeds:
:return tf.Tensor:
"""
seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
return position_ids
def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
elif mode == "linear":
return self._linear(input_ids)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None)
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if input_ids is not None:
input_shape = shape_list(input_ids)
else:
input_shape = shape_list(inputs_embeds)[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)
return embeddings
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
batch_size = shape_list(inputs)[0]
length = shape_list(inputs)[1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
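# Illustrative note, not part of this commit: mode="embedding" maps token ids to vectors, while
# mode="linear" reuses the same word embedding matrix as the output projection of an LM head.
#   hidden = embeddings(input_ids, mode="embedding")  # (batch_size, seq_len, hidden_size)
#   logits = embeddings(hidden, mode="linear")        # (batch_size, seq_len, vocab_size)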
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFLongformerIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFLongformerOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFLongformerPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
return pooled_output
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFLongformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class TFLongformerSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_heads = config.num_attention_heads
self.head_dim = int(config.hidden_size / config.num_attention_heads)
self.embed_dim = config.hidden_size
self.query = tf.keras.layers.Dense(
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
...@@ -128,13 +390,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="value_global",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.layer_id = layer_id
attention_window = config.attention_window[self.layer_id]
assert (
attention_window % 2 == 0
), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
...@@ -173,8 +433,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
batch_size, seq_len, embed_dim = shape_list(hidden_states)
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
...@@ -183,7 +443,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# normalize query
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
...@@ -217,7 +476,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
) = self._get_global_attn_indices(is_index_global_attn)
# this function is only relevant for global attention
attn_scores = tf.cond(
is_global_attn,
lambda: self._concat_with_global_key_attn_probs(
...@@ -243,7 +501,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# apply dropout
attn_probs = self.dropout(attn_probs, training=training)
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
# if global attention, compute sum of global and local attn
...@@ -266,6 +523,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
# compute value for global attention and overwrite to attention output
...@@ -303,6 +561,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
outputs = (attn_output, attn_probs)
return outputs
@staticmethod
...@@ -322,6 +581,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
with an overlap of size window_overlap"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
...@@ -341,7 +601,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
...@@ -390,7 +649,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
],
axis=1,
)
first_chunk_mask = (
tf.broadcast_to(
tf.range(chunks_count + 1)[None, :, None, None],
...@@ -403,7 +661,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
< 1
)
diagonal_attn_scores_low_triang = tf.where(
first_chunk_mask,
diagonal_attn_scores_first_chunk,
...@@ -425,6 +682,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
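# Illustrative sketch, not part of this commit, of the overlapping chunking used above with
# hidden_dim collapsed to 1: a sequence of length 8 and window_overlap w=2 is cut into
# seq_len // w - 1 = 3 chunks of size 2w that overlap by w.
#   x = tf.reshape(tf.range(8), (1, 8))
#   tf.signal.frame(x, frame_length=4, frame_step=2).numpy()
#   # -> [[[0 1 2 3], [2 3 4 5], [4 5 6 7]]]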
@staticmethod
...@@ -434,6 +692,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
axis=[0],
)
# pad to full matrix
padding = tf.constant(
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
...@@ -441,6 +700,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# create lower mask
mask_2d = tf.pad(mask_2d_upper, padding)
# combine with upper mask
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
...@@ -456,7 +716,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
return input_tensor
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
Returned tensor will be of the same shape as `attn_probs`"""
...@@ -479,8 +738,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
chunked_attn_probs = tf.reshape( chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)), tf.transpose(attn_probs, (0, 2, 1, 3)),
( (
...@@ -498,15 +757,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -498,15 +757,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
) )
# pad seq_len with w at the beginning of the sequence and another window overlap at the end # pad seq_len with w at the beginning of the sequence and another window overlap at the end
paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
padded_value = tf.pad(value, paddings, constant_values=-1) padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
frame_size = 3 * window_overlap * head_dim frame_size = 3 * window_overlap * head_dim
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
chunked_value = tf.signal.frame( chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)), tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size, frame_size,
...@@ -524,12 +780,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -524,12 +780,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
) )
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
context = tf.transpose( context = tf.transpose(
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
(0, 2, 1, 3), (0, 2, 1, 3),
) )
return context return context
@staticmethod @staticmethod
...@@ -538,7 +794,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -538,7 +794,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
hidden_states_padded = tf.pad( hidden_states_padded = tf.pad(
hidden_states_padded, paddings hidden_states_padded, paddings
) # padding value is not important because it will be overwritten ) # padding value is not important because it will be overwritten
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
...@@ -560,12 +815,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -560,12 +815,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
""" """
total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
chunked_hidden_states = tf.pad( chunked_hidden_states = tf.pad(
chunked_hidden_states, paddings chunked_hidden_states, paddings
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
chunked_hidden_states = tf.reshape( chunked_hidden_states = tf.reshape(
chunked_hidden_states, (total_num_heads, num_chunks, -1) chunked_hidden_states, (total_num_heads, num_chunks, -1)
) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
...@@ -577,6 +830,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -577,6 +830,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim), (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
) # total_num_heads x num_chunks x window_overlap x hidden_dim+window_overlap ) # total_num_heads x num_chunks x window_overlap x hidden_dim+window_overlap
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states return chunked_hidden_states
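A hedged toy example of the pad-flatten-reshape trick in _pad_and_diagonalize above: padding each row with window_overlap + 1 zeros and reshaping shifts row i right by i positions. The slice that drops the trailing window_overlap elements is hidden by the hunk above and is assumed here from the final reshape shapes.

import tensorflow as tf

window_overlap, hidden_dim = 3, 4
x = tf.reshape(tf.range(window_overlap * hidden_dim, dtype=tf.float32) + 1.0,
               (1, 1, window_overlap, hidden_dim))            # one head, one chunk, rows 1-4, 5-8, 9-12

padded = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
flat = tf.reshape(padded, (1, 1, -1))                         # window_overlap * (hidden_dim + window_overlap + 1)
flat = flat[:, :, :-window_overlap]                           # assumed truncation step (outside this hunk)
diag = tf.reshape(flat, (1, 1, window_overlap, window_overlap + hidden_dim))[:, :, :, :-1]
print(diag.numpy()[0, 0])
# [[ 1.  2.  3.  4.  0.  0.]
#  [ 0.  5.  6.  7.  8.  0.]
#  [ 0.  0.  9. 10. 11. 12.]]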
@staticmethod @staticmethod
...@@ -588,7 +842,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -588,7 +842,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# define frame size and frame stride (similar to convolution) # define frame size and frame stride (similar to convolution)
frame_hop_size = window_overlap * hidden_dim frame_hop_size = window_overlap * hidden_dim
frame_size = 2 * frame_hop_size frame_size = 2 * frame_hop_size
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
# chunk with overlap # chunk with overlap
...@@ -651,6 +904,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -651,6 +904,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# select global key vectors # select global key vectors
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
# create only global key vectors # create only global key vectors
key_vectors_only_global = tf.scatter_nd( key_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero, is_local_index_global_attn_nonzero,
...@@ -665,6 +919,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -665,6 +919,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# (batch_size, seq_len, num_heads, max_num_global_attn_indices) # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
# (batch_size, max_num_global_attn_indices, seq_len, num_heads) # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
...@@ -703,6 +958,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -703,6 +958,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# select global value vectors # select global value vectors
global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero) global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
# create only global value vectors # create only global value vectors
value_vectors_only_global = tf.scatter_nd( value_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero, is_local_index_global_attn_nonzero,
...@@ -725,6 +981,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -725,6 +981,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
attn_probs_without_global, value_vectors, self.one_sided_attn_window_size attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
) )
return attn_output_only_global + attn_output_without_global return attn_output_only_global + attn_output_without_global
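A minimal sketch (assumed toy shapes) of the gather_nd / scatter_nd pattern above: key vectors at global positions are gathered, packed into a compact (batch, max_num_global, heads, dim) tensor, and every token attends to those slots via the einsum. The slot-index bookkeeping (is_local_index_global_attn_nonzero) is reconstructed here and is an assumption, since the code that builds it lies outside these hunks.

import tensorflow as tf

batch_size, seq_len, num_heads, head_dim = 2, 5, 1, 3
max_num_global_attn_indices = 2

key_vectors = tf.random.normal((batch_size, seq_len, num_heads, head_dim))
query_vectors = tf.random.normal((batch_size, seq_len, num_heads, head_dim))
is_index_global_attn = tf.constant([[False, True, False, True, False],
                                    [True, False, False, False, False]])

is_index_global_attn_nonzero = tf.where(is_index_global_attn)                  # (num_global, 2)
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)   # (num_global, heads, dim)

# assumed helper: slot of each global token inside the compact tensor, as (batch_idx, slot_idx)
slots = tf.math.cumsum(tf.cast(is_index_global_attn, tf.int64), axis=1) - 1
is_local_index_global_attn_nonzero = tf.stack(
    [is_index_global_attn_nonzero[:, 0], tf.gather_nd(slots, is_index_global_attn_nonzero)], axis=1
)

key_vectors_only_global = tf.scatter_nd(
    is_local_index_global_attn_nonzero,
    global_key_vectors,
    shape=(batch_size, max_num_global_attn_indices, num_heads, head_dim),
)

# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
print(attn_probs_from_global_key.shape)  # (2, 5, 1, 2)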
def _compute_global_attn_output_from_hidden( def _compute_global_attn_output_from_hidden(
...@@ -755,7 +1012,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -755,7 +1012,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# normalize # normalize
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
...@@ -773,7 +1029,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -773,7 +1029,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_attn_scores, global_attn_scores,
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len), (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
) )
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
shape_list(global_attn_scores_trans)[-2:] shape_list(global_attn_scores_trans)[-2:]
...@@ -791,7 +1046,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -791,7 +1046,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# mask global attn scores # mask global attn scores
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape( global_attn_scores = tf.reshape(
global_attn_scores, global_attn_scores,
(batch_size * self.num_heads, max_num_global_attn_indices, seq_len), (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
...@@ -828,10 +1082,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -828,10 +1082,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
) )
# overwrite values with global attention # overwrite values with global attention
attn_output = tf.tensor_scatter_nd_update( attn_output = tf.tensor_scatter_nd_update(
attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
) )
return attn_output return attn_output
def reshape_and_transpose(self, vector, batch_size): def reshape_and_transpose(self, vector, batch_size):
...@@ -847,8 +1101,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ...@@ -847,8 +1101,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
class TFLongformerAttention(tf.keras.layers.Layer): class TFLongformerAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id=0, **kwargs): def __init__(self, config, layer_id=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self") self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
self.dense_output = TFBertSelfOutput(config, name="output") self.dense_output = TFLongformerSelfOutput(config, name="output")
def prune_heads(self, heads): def prune_heads(self, heads):
raise NotImplementedError raise NotImplementedError
...@@ -868,17 +1123,18 @@ class TFLongformerAttention(tf.keras.layers.Layer): ...@@ -868,17 +1123,18 @@ class TFLongformerAttention(tf.keras.layers.Layer):
training=training, training=training,
) )
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training) attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
outputs = (attention_output,) + self_outputs[1:] outputs = (attention_output,) + self_outputs[1:]
return outputs return outputs
class TFLongformerLayer(tf.keras.layers.Layer): class TFLongformerLayer(tf.keras.layers.Layer):
def __init__(self, config, layer_id=0, **kwargs): def __init__(self, config, layer_id=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.attention = TFLongformerAttention(config, layer_id, name="attention") self.attention = TFLongformerAttention(config, layer_id, name="attention")
self.intermediate = TFBertIntermediate(config, name="intermediate") self.intermediate = TFLongformerIntermediate(config, name="intermediate")
self.longformer_output = TFBertOutput(config, name="output") self.longformer_output = TFLongformerOutput(config, name="output")
def call(self, inputs, training=False): def call(self, inputs, training=False):
( (
...@@ -898,12 +1154,14 @@ class TFLongformerLayer(tf.keras.layers.Layer): ...@@ -898,12 +1154,14 @@ class TFLongformerLayer(tf.keras.layers.Layer):
intermediate_output = self.intermediate(attention_output) intermediate_output = self.intermediate(attention_output)
layer_output = self.longformer_output(intermediate_output, attention_output, training=training) layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs return outputs
class TFLongformerEncoder(tf.keras.layers.Layer): class TFLongformerEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.layer = [ self.layer = [
...@@ -926,6 +1184,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer): ...@@ -926,6 +1184,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
...@@ -954,6 +1213,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer): ...@@ -954,6 +1213,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput( return TFBaseModelOutput(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
...@@ -985,10 +1245,9 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -985,10 +1245,9 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self.return_dict = config.use_return_dict self.return_dict = config.use_return_dict
self.pad_token_id = config.pad_token_id self.pad_token_id = config.pad_token_id
self.attention_window = config.attention_window self.attention_window = config.attention_window
self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
self.encoder = TFLongformerEncoder(config, name="encoder") self.encoder = TFLongformerEncoder(config, name="encoder")
self.pooler = TFBertPooler(config, name="pooler") self.pooler = TFLongformerPooler(config, name="pooler")
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings return self.embeddings
...@@ -1084,6 +1343,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1084,6 +1343,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
is_index_masked = tf.math.less(attention_mask, 1) is_index_masked = tf.math.less(attention_mask, 1)
is_index_global_attn = tf.math.greater(attention_mask, 1) is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn) is_global_attn = tf.math.reduce_any(is_index_global_attn)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, to_seq_length, 1, 1] # Sizes are [batch_size, to_seq_length, 1, 1]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
...@@ -1097,7 +1357,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1097,7 +1357,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
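A small hedged sketch of the attention_mask convention used here (0 = padding, 1 = local attention, 2 = global attention). The exact reshape of the extended mask is an assumption based on the shape comment above, since that line falls outside this hunk.

import tensorflow as tf

attention_mask = tf.constant([[2, 1, 1, 1, 0, 0]])           # one global token, two padding tokens

is_index_masked = tf.math.less(attention_mask, 1)            # [[False, False, False, False, True, True]]
is_index_global_attn = tf.math.greater(attention_mask, 1)    # [[True, False, False, False, False, False]]
is_global_attn = tf.math.reduce_any(is_index_global_attn)    # True

extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]   # [batch_size, to_seq_length, 1, 1]
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
print(tf.reshape(extended_attention_mask, [-1]).numpy())
# [-10000.      0.      0.      0. -10000. -10000.]  -> padding is removed from local attention;
# the global token also gets -10000 here and is handled separately by the global-attention branch.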
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
...@@ -1111,7 +1370,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1111,7 +1370,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
...@@ -1149,22 +1407,27 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1149,22 +1407,27 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
) )
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
batch_size, seq_len = input_shape[:2] batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0: if padding_len > 0:
logger.info( logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window seq_len, seq_len + padding_len, attention_window
) )
) )
paddings = tf.constant([[0, 0], [0, padding_len]]) paddings = tf.constant([[0, 0], [0, padding_len]])
if input_ids is not None: if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
if position_ids is not None: if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
if inputs_embeds is not None: if inputs_embeds is not None:
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id) input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds_padding = self.embeddings(input_ids_padding)
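A minimal sketch (toy numbers) of the padding arithmetic above: sequences are right-padded with pad_token_id until their length is a multiple of config.attention_window.

import tensorflow as tf

attention_window, pad_token_id = 512, 1
seq_len = 1300

padding_len = (attention_window - seq_len % attention_window) % attention_window  # 236

input_ids = tf.ones((1, seq_len), dtype=tf.int32)
paddings = tf.constant([[0, 0], [0, padding_len]])
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
print(input_ids.shape)  # (1, 1536), i.e. 3 * 512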
...@@ -1195,6 +1458,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1195,6 +1458,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# simply use `global_attention_mask` as `attention_mask` # simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given # if no `attention_mask` is given
attention_mask = global_attention_mask + 1 attention_mask = global_attention_mask + 1
return attention_mask return attention_mask
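A hedged sketch of how a 0/1 global_attention_mask becomes the 0/1/2 attention_mask convention. The branch shown above covers the case with no attention_mask; the multiplicative merge for the case where both masks are given is an assumption, since it sits outside this hunk.

import tensorflow as tf

global_attention_mask = tf.constant([[1, 0, 0, 0]])            # first token gets global attention
attention_mask = tf.constant([[1, 1, 1, 0]])                   # last token is padding

print((global_attention_mask + 1).numpy())                     # [[2 1 1 1]] - no attention_mask given
print((attention_mask * (global_attention_mask + 1)).numpy())  # [[2 1 1 0]] - assumed merge of both masks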
...@@ -1339,11 +1603,13 @@ class TFLongformerModel(TFLongformerPreTrainedModel): ...@@ -1339,11 +1603,13 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.longformer = TFLongformerMainLayer(config, name="longformer") self.longformer = TFLongformerMainLayer(config, name="longformer")
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.longformer(inputs, **kwargs) outputs = self.longformer(inputs, **kwargs)
return outputs return outputs
...@@ -1356,7 +1622,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -1356,7 +1622,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.longformer = TFLongformerMainLayer(config, name="longformer") self.longformer = TFLongformerMainLayer(config, name="longformer")
self.lm_head = TFRobertaLMHead(config, self.longformer.embeddings, name="lm_head") self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head.decoder return self.lm_head.decoder
...@@ -1390,8 +1656,10 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -1390,8 +1656,10 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.longformer.return_dict return_dict = return_dict if return_dict is not None else self.longformer.return_dict
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9: if len(inputs) > 9:
inputs = inputs[:9] inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
...@@ -1409,14 +1677,13 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -1409,14 +1677,13 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training) prediction_scores = self.lm_head(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if labels is None else self.compute_loss(labels, prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput( return TFMaskedLMOutput(
...@@ -1435,8 +1702,8 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -1435,8 +1702,8 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config, name="longformer") self.longformer = TFLongformerMainLayer(config, name="longformer")
self.qa_outputs = tf.keras.layers.Dense( self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, config.num_labels,
...@@ -1477,6 +1744,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -1477,6 +1744,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
Positions outside of the sequence are not taken into account for computing the loss. Positions outside of the sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.longformer.return_dict return_dict = return_dict if return_dict is not None else self.longformer.return_dict
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
global_attention_mask = inputs[2] global_attention_mask = inputs[2]
...@@ -1520,15 +1788,13 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -1520,15 +1788,13 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1) start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1) start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions} labels = {"start_position": start_positions}
labels["end_position"] = end_positions labels["end_position"] = end_positions
...@@ -1536,6 +1802,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -1536,6 +1802,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
if not return_dict: if not return_dict:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TFQuestionAnsweringModelOutput( return TFQuestionAnsweringModelOutput(
......
...@@ -31,7 +31,6 @@ from .file_utils import ( ...@@ -31,7 +31,6 @@ from .file_utils import (
add_start_docstrings_to_callable, add_start_docstrings_to_callable,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_tf_bert import TFBertIntermediate
from .modeling_tf_outputs import ( from .modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPooling, TFBaseModelOutputWithPooling,
...@@ -63,11 +62,29 @@ _CONFIG_FOR_DOC = "MobileBertConfig" ...@@ -63,11 +62,29 @@ _CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer" _TOKENIZER_FOR_DOC = "MobileBertTokenizer"
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"mobilebert-uncased", "google/mobilebert-uncased",
# See all MobileBERT models at https://huggingface.co/models?filter=mobilebert # See all MobileBERT models at https://huggingface.co/models?filter=mobilebert
] ]
class TFMobileBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class TFLayerNorm(tf.keras.layers.LayerNormalization): class TFLayerNorm(tf.keras.layers.LayerNormalization):
def __init__(self, feat_size, *args, **kwargs): def __init__(self, feat_size, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -353,12 +370,6 @@ class TFMobileBertAttention(tf.keras.layers.Layer): ...@@ -353,12 +370,6 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
return outputs return outputs
class TFMobileBertIntermediate(TFBertIntermediate):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
class TFOutputBottleneck(tf.keras.layers.Layer): class TFOutputBottleneck(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
......
...@@ -26,8 +26,8 @@ from .file_utils import ( ...@@ -26,8 +26,8 @@ from .file_utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_callable, add_start_docstrings_to_callable,
) )
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer
from .modeling_tf_outputs import ( from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling, TFBaseModelOutputWithPooling,
TFMaskedLMOutput, TFMaskedLMOutput,
TFMultipleChoiceModelOutput, TFMultipleChoiceModelOutput,
...@@ -64,14 +64,48 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -64,14 +64,48 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
class TFRobertaEmbeddings(TFBertEmbeddings): class TFRobertaEmbeddings(tf.keras.layers.Layer):
""" """
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
""" """
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
self.padding_idx = 1 self.padding_idx = 1
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name="position_embeddings",
)
self.token_type_embeddings = tf.keras.layers.Embedding(
config.type_vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name="token_type_embeddings",
)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def build(self, input_shape):
"""Build shared word embedding layer """
with tf.name_scope("word_embeddings"):
# Create and initialize weights. The random normal initializer was chosen
# arbitrarily, and works well.
self.word_embeddings = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def create_position_ids_from_input_ids(self, x): def create_position_ids_from_input_ids(self, x):
"""Replace non-padding symbols with their position numbers. Position numbers begin at """Replace non-padding symbols with their position numbers. Position numbers begin at
...@@ -82,6 +116,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings): ...@@ -82,6 +116,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
""" """
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32) mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx return incremental_indices + self.padding_idx
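A minimal sketch (toy ids) of create_position_ids_from_input_ids above: non-padding tokens get positions padding_idx + 1, padding_idx + 2, ... while padding tokens keep position padding_idx, matching RoBERTa's convention.

import tensorflow as tf

padding_idx = 1
input_ids = tf.constant([[0, 31414, 232, 2, 1, 1]])     # the last two tokens are padding (id 1)

mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=tf.int32)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
print((incremental_indices + padding_idx).numpy())      # [[2 3 4 5 1 1]]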
def create_position_ids_from_inputs_embeds(self, inputs_embeds): def create_position_ids_from_inputs_embeds(self, inputs_embeds):
...@@ -91,10 +126,40 @@ class TFRobertaEmbeddings(TFBertEmbeddings): ...@@ -91,10 +126,40 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
:return tf.Tensor: :return tf.Tensor:
""" """
seq_length = shape_list(inputs_embeds)[1] seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :] position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
return position_ids return position_ids
def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" or "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
elif mode == "linear":
return self._linear(input_ids)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False): def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor.""" """Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
...@@ -106,19 +171,429 @@ class TFRobertaEmbeddings(TFBertEmbeddings): ...@@ -106,19 +171,429 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
return super()._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) if input_ids is not None:
input_shape = shape_list(input_ids)
else:
input_shape = shape_list(inputs_embeds)[:-1]
seq_length = input_shape[1]
@keras_serializable if position_ids is None:
class TFRobertaMainLayer(TFBertMainLayer): position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
"""
Same as TFBertMainLayer but uses TFRobertaEmbeddings. if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)
return embeddings
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
""" """
batch_size = shape_list(inputs)[0]
length = shape_list(inputs)[1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
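A minimal sketch (assumed small sizes) of the "linear" mode above: the word-embedding matrix is reused as the output projection, so logits are just hidden states multiplied by the embedding table transposed.

import tensorflow as tf

vocab_size, hidden_size, batch_size, length = 10, 4, 2, 7
word_embeddings = tf.random.normal((vocab_size, hidden_size))
hidden_states = tf.random.normal((batch_size, length, hidden_size))

x = tf.reshape(hidden_states, [-1, hidden_size])
logits = tf.matmul(x, word_embeddings, transpose_b=True)
logits = tf.reshape(logits, [batch_size, length, vocab_size])
print(logits.shape)  # (2, 7, 10)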
# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFRobertaPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(config, **kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
return pooled_output
# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
class TFRobertaSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
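A compact sketch (toy shapes, no masking or dropout) of the scaled dot-product core that TFRobertaSelfAttention implements above.

import tensorflow as tf

batch_size, num_heads, seq_len, head_dim = 2, 3, 5, 4

def transpose_for_scores(x):
    # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
    x = tf.reshape(x, (batch_size, -1, num_heads, head_dim))
    return tf.transpose(x, perm=[0, 2, 1, 3])

q = transpose_for_scores(tf.random.normal((batch_size, seq_len, num_heads * head_dim)))
k = transpose_for_scores(tf.random.normal((batch_size, seq_len, num_heads * head_dim)))
v = transpose_for_scores(tf.random.normal((batch_size, seq_len, num_heads * head_dim)))

scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(float(head_dim))   # (batch, heads, seq, seq)
probs = tf.nn.softmax(scores, axis=-1)
context = tf.matmul(probs, v)                                                # (batch, heads, seq, head_dim)
context = tf.reshape(tf.transpose(context, perm=[0, 2, 1, 3]),
                     (batch_size, -1, num_heads * head_dim))                 # (batch, seq, heads * head_dim)
print(context.shape)  # (2, 5, 12)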
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFRobertaSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Roberta
class TFRobertaAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFRobertaSelfAttention(config, name="self")
self.dense_output = TFRobertaSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
self_outputs = self.self_attention(
input_tensor, attention_mask, head_mask, output_attentions, training=training
)
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFRobertaIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFRobertaOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Roberta
class TFRobertaLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.attention = TFRobertaAttention(config, name="attention")
self.intermediate = TFRobertaIntermediate(config, name="intermediate")
self.bert_output = TFRobertaOutput(config, name="output")
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions, training=training
)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Roberta
class TFRobertaEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.layer = [TFRobertaLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states,
attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@keras_serializable
class TFRobertaMainLayer(tf.keras.layers.Layer):
config_class = RobertaConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.encoder = TFRobertaEncoder(config, name="encoder")
self.pooler = TFRobertaPooler(config, name="pooler")
# The embeddings must be the last declaration in order to follow the weights order
self.embeddings = TFRobertaEmbeddings(config, name="embeddings") self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
# Copied from transformers.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
raise NotImplementedError
# Copied from transformers.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
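A hedged usage sketch of the three input formats the call() above accepts: a tensor plus keyword arguments, an ordered tuple/list, or a dict/BatchEncoding. The model and tokenizer names are illustrative assumptions, not part of this diff.

from transformers import RobertaTokenizer, TFRobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = TFRobertaModel.from_pretrained("roberta-base")
encoding = tokenizer("Hello world", return_tensors="tf")

outputs = model(encoding)                                                          # dict / BatchEncoding
outputs = model(encoding["input_ids"], attention_mask=encoding["attention_mask"])  # tensor + keywords
outputs = model([encoding["input_ids"], encoding["attention_mask"]])               # ordered list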
class TFRobertaPreTrainedModel(TFPreTrainedModel): class TFRobertaPreTrainedModel(TFPreTrainedModel):
"""An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
...@@ -246,6 +721,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer): ...@@ -246,6 +721,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs): def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.dense = tf.keras.layers.Dense( self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
...@@ -259,6 +735,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer): ...@@ -259,6 +735,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape) super().build(input_shape)
def call(self, features): def call(self, features):
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import itertools import itertools
import math
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -114,13 +113,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): ...@@ -114,13 +113,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
return mask, attn_mask return mask, attn_mask
class TFMultiHeadAttention(tf.keras.layers.Layer): class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
NEW_ID = itertools.count() NEW_ID = itertools.count()
def __init__(self, n_heads, dim, config, **kwargs): def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.layer_id = next(TFMultiHeadAttention.NEW_ID) self.layer_id = next(TFXLMMultiHeadAttention.NEW_ID)
self.dim = dim self.dim = dim
self.n_heads = n_heads self.n_heads = n_heads
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
...@@ -143,13 +141,15 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -143,13 +141,15 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
# Input is (bs, qlen, dim) # Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen) # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input) bs, qlen, dim = shape_list(input)
if kv is None: if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen klen = qlen if cache is None else cache["slen"] + qlen
else: else:
klen = shape_list(kv)[1] klen = shape_list(kv)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads = self.n_heads dim_per_head = tf.math.divide(self.dim, self.n_heads)
dim_per_head = self.dim // n_heads dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
def shape(x): def shape(x):
...@@ -161,6 +161,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -161,6 +161,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None: if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
...@@ -177,14 +178,17 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -177,14 +178,17 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else: else:
k, v = cache[self.layer_id] k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v) cache[self.layer_id] = (k, v)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) q = tf.cast(q, dtype=tf.float32)
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask) scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
...@@ -194,16 +198,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -194,16 +198,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim) context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),) outputs = (self.out_lin(context),)
if output_attentions: if output_attentions:
outputs = outputs + (weights,) outputs = outputs + (weights,)
return outputs return outputs
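A minimal check (toy values) of the scaling change above: multiplying q by tf.math.rsqrt(dim_per_head) is numerically the same as dividing by math.sqrt(dim_per_head), but stays inside the TensorFlow graph.

import math
import tensorflow as tf

dim_per_head = 64
q = tf.random.normal((2, 4, 10, dim_per_head))

scaled_old = q / math.sqrt(dim_per_head)
scaled_new = q * tf.math.rsqrt(tf.cast(dim_per_head, tf.float32))
print(bool(tf.reduce_all(tf.abs(scaled_old - scaled_new) < 1e-6)))  # True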
class TFTransformerFFN(tf.keras.layers.Layer): class TFXLMTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs): def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1") self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2") self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu") self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
...@@ -214,6 +220,7 @@ class TFTransformerFFN(tf.keras.layers.Layer): ...@@ -214,6 +220,7 @@ class TFTransformerFFN(tf.keras.layers.Layer):
x = self.act(x) x = self.act(x)
x = self.lin2(x) x = self.lin2(x)
x = self.dropout(x, training=training) x = self.dropout(x, training=training)
return x return x
...@@ -223,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict
...@@ -230,8 +238,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# encoder / decoder, output layer
self.is_encoder = config.is_encoder
self.is_decoder = not config.is_encoder
if self.is_decoder:
raise NotImplementedError("Currently XLM can only be used as an encoder")
# self.with_output = with_output
self.causal = config.causal
...@@ -257,16 +267,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# embeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="position_embeddings",
)
if config.sinusoidal_embeddings:
raise NotImplementedError
# create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = tf.keras.layers.Embedding(
self.n_langs,
...@@ -274,6 +285,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
embeddings_initializer=get_initializer(config.embed_init_std),
name="lang_embeddings",
)
self.embeddings = TFSharedEmbeddings(
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
)  # padding_idx=self.pad_index)
...@@ -290,7 +302,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
for i in range(self.n_layers):
self.attentions.append(
-TFMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
+TFXLMMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
)
self.layer_norm1.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
...@@ -299,7 +311,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(
-TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
+TFXLMTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
)
self.layer_norm2.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
...@@ -308,6 +320,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
...@@ -398,7 +411,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
# assert shape_list(lengths)[0] == bs
-tf.debugging.assert_equal(shape_list(lengths)[0], bs)
+tf.debugging.assert_equal(
+    shape_list(lengths)[0], bs, message=f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
+)
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
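
As written here, the error text is passed through `tf.debugging.assert_equal`'s `message=` keyword, which is what gets reported when the shapes disagree; in eager mode a failed check raises `tf.errors.InvalidArgumentError` immediately. A tiny illustration with made-up values:

```python
import tensorflow as tf

try:
    tf.debugging.assert_equal(3, 4, message="Expected batch size 3 and received batch size 4 mismatched")
except tf.errors.InvalidArgumentError as err:
    print(err.message)   # the message= text shows up in the raised error
```
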
...@@ -416,13 +431,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
-tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
+tf.debugging.assert_equal(
+    shape_list(position_ids), [bs, slen], message=f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
+)
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
# assert shape_list(langs) == [bs, slen]  # (slen, bs)
-tf.debugging.assert_equal(shape_list(langs), [bs, slen])
+tf.debugging.assert_equal(
+    shape_list(langs), [bs, slen], message=f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
+)
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
...@@ -455,6 +474,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = tensor * mask[..., tf.newaxis]
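
Putting the embedding hunks together: word, position and (optionally) language and token-type embeddings are summed, layer-normalized, passed through dropout, and multiplied by the padding mask so padded positions contribute zero vectors. A condensed, runnable sketch with made-up toy sizes (it skips the language and token-type terms, which add in exactly the same way):

```python
import tensorflow as tf

vocab_size, max_positions, dim, bs, slen = 100, 32, 8, 2, 5   # toy sizes
word_emb = tf.keras.layers.Embedding(vocab_size, dim)
pos_emb = tf.keras.layers.Embedding(max_positions, dim)
layer_norm_emb = tf.keras.layers.LayerNormalization(epsilon=1e-12)
dropout = tf.keras.layers.Dropout(0.1)

input_ids = tf.random.uniform((bs, slen), maxval=vocab_size, dtype=tf.int32)
position_ids = tf.expand_dims(tf.range(slen), axis=0)
mask = tf.constant([[1.0, 1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0, 0.0, 0.0]])               # second sequence has two padded positions

tensor = word_emb(input_ids) + pos_emb(position_ids)          # summed embeddings
tensor = layer_norm_emb(tensor)
tensor = dropout(tensor, training=False)
tensor = tensor * mask[..., tf.newaxis]                       # zero out padded positions
print(tensor.shape)                                           # (2, 5, 8)
```
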
...@@ -462,6 +482,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# transformer layers
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
...@@ -471,8 +492,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
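
Each of the `n_layers` blocks applies self-attention, dropout, a residual connection and `layer_norm1` as shown above (the FFN plus `layer_norm2` half, elided from the hunk, follows the same residual-then-normalize pattern). A toy sketch of that pattern, using Keras' built-in `MultiHeadAttention` only as a stand-in for `TFXLMMultiHeadAttention` (sizes are made up):

```python
import tensorflow as tf

dim = 8
attention = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=dim // 2)  # stand-in for attentions[i]
layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-12)
dropout = tf.keras.layers.Dropout(0.1)

tensor = tf.random.normal((2, 5, dim))
attn = attention(tensor, tensor)                              # self-attention over the sequence
tensor = layer_norm1(tensor + dropout(attn, training=False))  # residual connection, then layer norm
print(tensor.shape)                                           # (2, 5, 8)
```
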
...@@ -502,6 +525,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
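
The main layer returns either a plain tuple or a `TFBaseModelOutput`, depending on `return_dict`. A usage sketch of the public model, assuming the `xlm-mlm-en-2048` checkpoint (hidden size 2048); adjust the checkpoint name as needed:

```python
from transformers import TFXLMModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = TFXLMModel.from_pretrained("xlm-mlm-en-2048")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(inputs, return_dict=True)         # TFBaseModelOutput
print(outputs.last_hidden_state.shape)            # (1, sequence_length, 2048)

tuple_outputs = model(inputs, return_dict=False)  # plain tuple; tuple_outputs[0] is last_hidden_state
```
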
...@@ -691,9 +715,11 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.asm = config.asm
self.n_words = config.n_words
self.pad_index = config.pad_index
if config.asm is False:
self.input_embeddings = input_embeddings
else:
...@@ -709,11 +735,13 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def build(self, input_shape):
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)

def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
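
`TFXLMPredLayer` ties the output projection to the input embeddings: calling the shared embedding layer with `mode="linear"` effectively multiplies the hidden states by the transposed embedding matrix, and only a per-token bias is trained on top. A minimal sketch of that computation with made-up sizes:

```python
import tensorflow as tf

n_words, dim = 100, 8                                 # toy sizes
embedding_matrix = tf.random.normal((n_words, dim))   # in the model this matrix is shared with the input embeddings
bias = tf.zeros((n_words,))

hidden_states = tf.random.normal((2, 5, dim))
logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + bias
print(logits.shape)                                   # (2, 5, 100): one score per vocabulary item
```
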