Unverified commit d161ed16 authored by Julien Plu, committed by GitHub

Update the TF models to remove their interdependencies (#7238)

* Refactor the models to remove their interdependencies

* Fix Flaubert model

* Fix Flaubert

* Fix XLM

* Fix Albert

* Fix Roberta

* Fix Albert

* Fix Flaubert

* Apply style + remove unused imports

* Fix Distilbert

* Remove unused import

* Fix Distilbert

* Fix Flaubert

* Apply style

* Fix Flaubert

* Add the copy comments for the check_copies script

* Fix MobileBert model name

* Address Morgan's comments

* Fix typo

* Oops typo
parent 0cffa424
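The "copy comments" mentioned in the log feed the repository's `check_copies` utility: instead of Albert, MobileBERT, and friends inheriting layers from `modeling_tf_bert`, each model now carries its own copy of the layer, marked so the script can verify the copies stay in sync with the BERT original. A rough sketch of the convention (the comment path and class below are illustrative, not quoted from this diff):

```python
import tensorflow as tf

# Copied from transformers.modeling_tf_bert.TFBertSelfOutput with Bert->Albert
class TFAlbertSelfOutput(tf.keras.layers.Layer):
    ...
```

A checker can then re-derive the expected body from the referenced class and fail CI when the two drift apart.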
@@ -31,7 +31,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
-from .modeling_tf_bert import TFBertSelfAttention
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
@@ -181,82 +180,6 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
return tf.reshape(logits, [batch_size, length, self.config.vocab_size])
-class TFAlbertSelfAttention(tf.keras.layers.Layer):
-def __init__(self, config, **kwargs):
-super().__init__(**kwargs)
-if config.hidden_size % config.num_attention_heads != 0:
-raise ValueError(
-"The hidden size (%d) is not a multiple of the number of attention "
-"heads (%d)" % (config.hidden_size, config.num_attention_heads)
-)
-self.num_attention_heads = config.num_attention_heads
-assert (
-config.hidden_size % config.num_attention_heads == 0
-), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
-self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-self.all_head_size = self.num_attention_heads * self.attention_head_size
-self.output_attentions = config.output_attentions
-self.query = tf.keras.layers.Dense(
-self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
-)
-self.key = tf.keras.layers.Dense(
-self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
-)
-self.value = tf.keras.layers.Dense(
-self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
-)
-self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
-def transpose_for_scores(self, x, batch_size):
-x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
-return tf.transpose(x, perm=[0, 2, 1, 3])
-def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
-batch_size = shape_list(hidden_states)[0]
-mixed_query_layer = self.query(hidden_states)
-mixed_key_layer = self.key(hidden_states)
-mixed_value_layer = self.value(hidden_states)
-query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
-# Take the dot product between "query" and "key" to get the raw attention scores.
-# (batch size, num_heads, seq_len_q, seq_len_k)
-attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
-# scale attention_scores
-dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
-attention_scores = attention_scores / tf.math.sqrt(dk)
-if attention_mask is not None:
-# Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function)
-attention_scores = attention_scores + attention_mask
-# Normalize the attention scores to probabilities.
-attention_probs = tf.nn.softmax(attention_scores, axis=-1)
-# This is actually dropping out entire tokens to attend to, which might
-# seem a bit unusual, but is taken from the original Transformer paper.
-attention_probs = self.dropout(attention_probs, training=training)
-# Mask heads if we want to
-if head_mask is not None:
-attention_probs = attention_probs * head_mask
-context_layer = tf.matmul(attention_probs, value_layer)
-context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
-context_layer = tf.reshape(
-context_layer, (batch_size, -1, self.all_head_size)
-) # (batch_size, seq_len_q, all_head_size)
-outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
-return outputs
class TFAlbertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
@@ -273,7 +196,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
return hidden_states
-class TFAlbertAttention(TFBertSelfAttention):
+class TFAlbertAttention(tf.keras.layers.Layer):
""" Contains the complete attention sublayer, including both dropouts and layer norm. """
def __init__(self, config, **kwargs):
@@ -281,6 +204,19 @@ class TFAlbertAttention(TFBertSelfAttention):
self.hidden_size = config.hidden_size
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
+assert config.hidden_size % config.num_attention_heads == 0
+self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+self.all_head_size = self.num_attention_heads * self.attention_head_size
+self.query = tf.keras.layers.Dense(
+self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+)
+self.key = tf.keras.layers.Dense(
+self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+)
+self.value = tf.keras.layers.Dense(
+self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -290,6 +226,11 @@ class TFAlbertAttention(TFBertSelfAttention):
self.attention_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+def transpose_for_scores(self, x, batch_size):
+x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
+return tf.transpose(x, perm=[0, 2, 1, 3])
def prune_heads(self, heads):
raise NotImplementedError
@@ -342,6 +283,7 @@ class TFAlbertAttention(TFBertSelfAttention):
# add attentions if we output them
outputs = (attention_output,) + self_outputs[1:]
return outputs
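The `transpose_for_scores` helper that now lives directly on `TFAlbertAttention` only rearranges the fused query/key/value projections into per-head slices. A minimal standalone sketch with toy sizes (the names and sizes below are illustrative):

```python
import tensorflow as tf

batch_size, seq_len, num_heads, head_size = 2, 5, 4, 8
hidden = tf.random.normal((batch_size, seq_len, num_heads * head_size))

# (bs, seq_len, all_head_size) -> (bs, num_heads, seq_len, head_size)
x = tf.reshape(hidden, (batch_size, -1, num_heads, head_size))
x = tf.transpose(x, perm=[0, 2, 1, 3])
print(x.shape)  # (2, 4, 5, 8)
```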
......
@@ -93,6 +93,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
@@ -124,6 +125,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def call(
@@ -273,6 +275,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -290,6 +293,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
class TFBertAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFBertSelfAttention(config, name="self")
self.dense_output = TFBertSelfOutput(config, name="output")
@@ -309,6 +313,7 @@ class TFBertAttention(tf.keras.layers.Layer):
class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -328,6 +333,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -345,6 +351,7 @@ class TFBertOutput(tf.keras.layers.Layer):
class TFBertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.attention = TFBertAttention(config, name="attention")
self.intermediate = TFBertIntermediate(config, name="intermediate")
self.bert_output = TFBertOutput(config, name="output")
@@ -364,6 +371,7 @@ class TFBertLayer(tf.keras.layers.Layer):
class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(
@@ -397,6 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@@ -405,6 +414,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
@@ -424,6 +434,7 @@ class TFBertPooler(tf.keras.layers.Layer):
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
@@ -446,6 +457,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
class TFBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
self.transform = TFBertPredictionHeadTransform(config, name="transform")
@@ -455,6 +467,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
@@ -468,6 +481,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
class TFBertMLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
def call(self, sequence_output):
@@ -479,6 +493,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(
2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
)
@@ -495,6 +510,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
@@ -571,6 +587,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
@@ -588,7 +605,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
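As the comment above notes, the mask is additive: `(1.0 - mask) * -10000.0` maps kept positions to 0 and masked positions to a large negative logit, so softmax assigns them near-zero weight. A small numeric sketch of the effect:

```python
import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5]])
mask = tf.constant([[1.0, 1.0, 0.0]])  # 1.0 = attend, 0.0 = masked

extended = (1.0 - mask) * -10000.0     # [[0., 0., -10000.]]
probs = tf.nn.softmax(scores + extended, axis=-1)
print(probs.numpy())  # third position gets ~0 probability
```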
@@ -767,6 +783,7 @@ BERT_INPUTS_DOCSTRING = r"""
class TFBertModel(TFBertPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name="bert")
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -778,6 +795,7 @@ class TFBertModel(TFBertPreTrainedModel):
)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
return outputs
@@ -818,7 +836,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
seq_relationship_score = self.nsp(pooled_output)
@@ -880,6 +897,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -902,7 +920,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
if not return_dict:
@@ -956,6 +973,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
Indices should be in ``[0, ..., config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -978,8 +996,8 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
sequence_output = outputs[0]
logits = self.mlm(sequence_output, training=training)
loss = None
if labels is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
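The shift above implements the standard causal-LM objective: the logits at position t are scored against the token at position t+1, so the last logit and the first label drop out. Roughly (a sketch ignoring padding and the model's actual loss reduction):

```python
import tensorflow as tf

labels = tf.constant([[5, 7, 9, 2]])    # (bs, seq_len)
logits = tf.random.normal((1, 4, 100))  # (bs, seq_len, vocab_size)

shifted_logits = logits[:, :-1]  # nothing follows the last position
shifted_labels = labels[:, 1:]   # position t predicts token t+1

loss = tf.keras.losses.sparse_categorical_crossentropy(
    shifted_labels, shifted_logits, from_logits=True
)
print(loss.shape)  # (1, 3): one loss term per predicted position
```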
@@ -1033,7 +1051,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)
pooled_output = outputs[1]
seq_relationship_score = self.nsp(pooled_output)
@@ -1055,8 +1072,8 @@
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
-self.num_labels = config.num_labels
+self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
@@ -1092,6 +1109,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -1113,10 +1131,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
@@ -1208,6 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if input_ids is not None:
@@ -1242,7 +1259,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
if not return_dict:
@@ -1265,8 +1281,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
-self.num_labels = config.num_labels
+self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
@@ -1300,6 +1316,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
Indices should be in ``[0, ..., config.num_labels - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
@@ -1319,12 +1336,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
@@ -1347,8 +1361,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
-self.num_labels = config.num_labels
+self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
@@ -1387,6 +1401,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
Position outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
@@ -1408,15 +1423,13 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
......
@@ -16,8 +16,6 @@
"""
import math
import tensorflow as tf
from .activations_tf import get_tf_activation
@@ -217,9 +215,8 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
k_length = shape_list(key)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
# assert key.size() == value.size()
-dim_per_head = self.dim // self.n_heads
+dim_per_head = tf.math.divide(self.dim, self.n_heads)
+dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = [bs, 1, 1, k_length]
def shape(x):
@@ -233,17 +230,16 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
-q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
+q = tf.cast(q, dtype=tf.float32)
+q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
+k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length)
-scores_dtype = scores.dtype
-# calculate `scores` in `tf.float32` to avoid numeric overflow
-scores = tf.cast(scores, dtype=tf.float32) - 1e30 * (1.0 - tf.cast(mask, dtype=tf.float32))
-weights = tf.cast(tf.nn.softmax(scores, axis=-1), dtype=scores_dtype) # (bs, n_heads, qlen, klen)
+mask = tf.cast(mask, dtype=scores.dtype)
+scores = scores - 1e30 * (1.0 - mask)
+weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
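The new scaling keeps the same 1/sqrt(d) factor but builds it from graph ops: `tf.math.divide` on two Python ints yields a float tensor (hence the cast back to `tf.int32` for `dim_per_head`), and `tf.math.rsqrt` replaces `math.sqrt`. A quick equivalence check with toy shapes (values here are illustrative, agreement is up to float rounding):

```python
import math
import tensorflow as tf

dim, n_heads = 768, 12
dim_per_head = tf.cast(tf.math.divide(dim, n_heads), dtype=tf.int32)  # 64

q = tf.random.normal((2, n_heads, 5, 64))
old = q / math.sqrt(64)
new = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
print(tf.reduce_max(tf.abs(old - new)).numpy())  # ~0.0
```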
......
[Three more file diffs are collapsed.]
@@ -31,7 +31,6 @@ from .file_utils import (
add_start_docstrings_to_callable,
replace_return_docstrings,
)
-from .modeling_tf_bert import TFBertIntermediate
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
@@ -63,11 +62,29 @@ _CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-"mobilebert-uncased",
+"google/mobilebert-uncased",
# See all MobileBERT models at https://huggingface.co/models?filter=mobilebert
]
+class TFMobileBertIntermediate(tf.keras.layers.Layer):
+def __init__(self, config, **kwargs):
+super().__init__(**kwargs)
+self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
+if isinstance(config.hidden_act, str):
+self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+else:
+self.intermediate_act_fn = config.hidden_act
+def call(self, hidden_states):
+hidden_states = self.dense(hidden_states)
+hidden_states = self.intermediate_act_fn(hidden_states)
+return hidden_states
class TFLayerNorm(tf.keras.layers.LayerNormalization):
def __init__(self, feat_size, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -353,12 +370,6 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
return outputs
-class TFMobileBertIntermediate(TFBertIntermediate):
-def __init__(self, config, **kwargs):
-super().__init__(config, **kwargs)
-self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
class TFOutputBottleneck(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......
[Another file diff is collapsed.]
@@ -17,7 +17,6 @@
import itertools
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -114,13 +113,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
return mask, attn_mask
-class TFMultiHeadAttention(tf.keras.layers.Layer):
+class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
NEW_ID = itertools.count()
def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs)
-self.layer_id = next(TFMultiHeadAttention.NEW_ID)
+self.layer_id = next(TFXLMMultiHeadAttention.NEW_ID)
self.dim = dim
self.n_heads = n_heads
self.output_attentions = config.output_attentions
@@ -143,13 +141,15 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input)
if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen
else:
klen = shape_list(kv)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads = self.n_heads
-dim_per_head = self.dim // n_heads
+dim_per_head = tf.math.divide(self.dim, self.n_heads)
+dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
def shape(x):
@@ -161,6 +161,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
@@ -177,14 +178,17 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
-q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
+q = tf.cast(q, dtype=tf.float32)
+q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
+k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
+mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
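The hunk above also shows the layer's key/value cache: when decoding incrementally, freshly projected keys and values are concatenated onto the cached ones along the sequence axis before the cache is written back. A stripped-down sketch of that bookkeeping (standalone code with hypothetical names, not the XLM API):

```python
import tensorflow as tf

cache = {}    # layer_id -> (k, v), as in the layer above
layer_id = 0

def append_to_cache(k_new, v_new):
    # accumulate keys/values along the klen axis: (bs, n_heads, klen, dim_per_head)
    if layer_id in cache:
        k_old, v_old = cache[layer_id]
        k = tf.concat([k_old, k_new], axis=2)
        v = tf.concat([v_old, v_new], axis=2)
    else:
        k, v = k_new, v_new
    cache[layer_id] = (k, v)
    return k, v

k, v = append_to_cache(tf.zeros((1, 4, 1, 8)), tf.zeros((1, 4, 1, 8)))
k, v = append_to_cache(tf.zeros((1, 4, 1, 8)), tf.zeros((1, 4, 1, 8)))
print(k.shape)  # (1, 4, 2, 8): two cached steps
```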
@@ -194,16 +198,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if output_attentions:
outputs = outputs + (weights,)
return outputs
-class TFTransformerFFN(tf.keras.layers.Layer):
+class TFXLMTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
super().__init__(**kwargs)
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
@@ -214,6 +220,7 @@ class TFTransformerFFN(tf.keras.layers.Layer):
x = self.act(x)
x = self.lin2(x)
x = self.dropout(x, training=training)
return x
@@ -223,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict
@@ -230,8 +238,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# encoder / decoder, output layer
self.is_encoder = config.is_encoder
self.is_decoder = not config.is_encoder
if self.is_decoder:
raise NotImplementedError("Currently XLM can only be used as an encoder")
# self.with_output = with_output
self.causal = config.causal
@@ -257,16 +267,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# embeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="position_embeddings",
)
if config.sinusoidal_embeddings:
raise NotImplementedError
# create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = tf.keras.layers.Embedding(
self.n_langs,
@@ -274,6 +285,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
embeddings_initializer=get_initializer(config.embed_init_std),
name="lang_embeddings",
)
self.embeddings = TFSharedEmbeddings(
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
) # padding_idx=self.pad_index)
@@ -290,7 +302,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
for i in range(self.n_layers):
self.attentions.append(
-TFMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
+TFXLMMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
)
self.layer_norm1.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
@@ -299,7 +311,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(
-TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
+TFXLMTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
)
self.layer_norm2.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
@@ -308,6 +320,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
@@ -398,7 +411,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
# assert shape_list(lengths)[0] == bs
-tf.debugging.assert_equal(shape_list(lengths)[0], bs)
+tf.debugging.assert_equal(
+shape_list(lengths)[0], bs
+), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@@ -416,13 +431,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
-tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
+tf.debugging.assert_equal(
+shape_list(position_ids), [bs, slen]
+), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
-tf.debugging.assert_equal(shape_list(langs), [bs, slen])
+tf.debugging.assert_equal(
+shape_list(langs), [bs, slen]
+), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
@@ -455,6 +474,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = tensor * mask[..., tf.newaxis]
@@ -462,6 +482,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# transformer layers
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
@@ -471,8 +492,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
@@ -502,6 +525,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
@@ -691,9 +715,11 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.asm = config.asm
self.n_words = config.n_words
self.pad_index = config.pad_index
if config.asm is False:
self.input_embeddings = input_embeddings
else:
@@ -709,11 +735,13 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def build(self, input_shape):
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
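`TFXLMPredLayer` ties its output projection to the input embedding matrix (the `mode="linear"` call) and only owns the per-token bias created in `build`, as the comment there says. A minimal sketch of weight tying in plain TF (hypothetical shapes, not the model's API):

```python
import tensorflow as tf

vocab_size, hidden = 100, 16
embeddings = tf.Variable(tf.random.normal((vocab_size, hidden)))  # shared table
bias = tf.Variable(tf.zeros((vocab_size,)))                       # output-only bias

hidden_states = tf.random.normal((2, 5, hidden))
# "linear" mode: multiply by the transposed embedding matrix
logits = tf.matmul(hidden_states, embeddings, transpose_b=True) + bias
print(logits.shape)  # (2, 5, 100)
```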
......