Unverified commit 31563e05, authored by Julien Plu and committed by GitHub

Restore TF embeddings and attention layers to their previous version (#9890)

* Refacto BERT

* Restore all the concerned models

* Remove print

* Update template

* Apply Sylvain's and Morgan's comments

* Fix cast

* Put the cast inside call

* Remove cond in ebds

* Fix funnel

* Restore previous dot product (attention_scores) computation

* Add ConvBERT and BART

* Make all the S2S models ONNX compliant

* Fix test

* Fix check copies
parent 8bb52bd2
@@ -17,7 +17,7 @@
from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple
import tensorflow as tf
@@ -73,157 +73,52 @@ TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFAlbertWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFAlbertTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
class TFAlbertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFAlbertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
-self.word_embeddings = TFAlbertWordEmbeddings(
-vocab_size=config.vocab_size,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="word_embeddings",
-)
-self.position_embeddings = TFAlbertPositionEmbeddings(
-max_position_embeddings=config.max_position_embeddings,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="position_embeddings",
-)
-self.token_type_embeddings = TFAlbertTokenTypeEmbeddings(
-type_vocab_size=config.type_vocab_size,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="token_type_embeddings",
-)
+self.vocab_size = config.vocab_size
+self.type_vocab_size = config.type_vocab_size
+self.embedding_size = config.embedding_size
+self.max_position_embeddings = config.max_position_embeddings
+self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(
self,
-input_ids: tf.Tensor,
-position_ids: tf.Tensor,
-token_type_ids: tf.Tensor,
-inputs_embeds: tf.Tensor,
+input_ids: tf.Tensor = None,
+position_ids: tf.Tensor = None,
+token_type_ids: tf.Tensor = None,
+inputs_embeds: tf.Tensor = None,
training: bool = False,
) -> tf.Tensor:
"""
@@ -235,18 +130,19 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
-inputs_embeds = self.word_embeddings(input_ids=input_ids)
+inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
-input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
-position_embeds = self.position_embeddings(position_ids=inputs_embeds)
-else:
-position_embeds = self.position_embeddings(position_ids=position_ids)
-token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids)
+position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
+token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
@@ -301,6 +197,7 @@ class TFAlbertAttention(tf.keras.layers.Layer):
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def transpose_for_scores(self, x, batch_size):
+# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
@@ -326,7 +223,7 @@ class TFAlbertAttention(tf.keras.layers.Layer):
attention_scores = attention_scores / tf.math.sqrt(dk)
if attention_mask is not None:
-# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+# Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
@@ -583,11 +480,11 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
)
def get_input_embeddings(self):
-return self.embeddings.word_embeddings
+return self.embeddings
def set_input_embeddings(self, value):
-self.embeddings.word_embeddings.weight = value
-self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
+self.embeddings.weight = value
+self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
@@ -914,7 +811,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
self.num_labels = config.num_labels
self.albert = TFAlbertMainLayer(config, name="albert")
-self.predictions = TFAlbertMLMHead(config, self.albert.embeddings.word_embeddings, name="predictions")
+self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
def get_lm_head(self):
@@ -1034,7 +931,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
super().__init__(config, *inputs, **kwargs)
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
-self.predictions = TFAlbertMLMHead(config, self.albert.embeddings.word_embeddings, name="predictions")
+self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
def get_lm_head(self):
return self.predictions
...
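Note (not part of the commit): the restored TFAlbertEmbeddings above follows the older transformers pattern of creating the embedding tables in build() and looking them up with tf.gather in call(), instead of routing through dedicated word/position/token-type sub-layers. Below is a minimal, self-contained sketch of that pattern with made-up sizes and without LayerNorm/dropout; it is illustrative only, not the library implementation. The tf.name_scope calls mirror the diff and presumably keep variable names compatible with existing checkpoints.

    import tensorflow as tf

    class RestoredEmbeddings(tf.keras.layers.Layer):
        """Sketch only: embedding tables created in build(), looked up with tf.gather."""

        def __init__(self, vocab_size=30000, max_position_embeddings=512, hidden_size=128, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size
            self.max_position_embeddings = max_position_embeddings
            self.hidden_size = hidden_size

        def build(self, input_shape):
            # One weight matrix per embedding type, created directly on this layer.
            with tf.name_scope("word_embeddings"):
                self.weight = self.add_weight(name="weight", shape=[self.vocab_size, self.hidden_size])
            with tf.name_scope("position_embeddings"):
                self.position_embeddings = self.add_weight(
                    name="embeddings", shape=[self.max_position_embeddings, self.hidden_size]
                )
            super().build(input_shape)

        def call(self, input_ids):
            input_shape = tf.shape(input_ids)
            # Word lookup by gather, position lookup by gather + tile across the batch.
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
            position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
            position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
            position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
            return inputs_embeds + position_embeds

    print(RestoredEmbeddings()(tf.constant([[5, 7, 9]])).shape)  # (1, 3, 128)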
@@ -92,17 +92,17 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
-return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
-bsz, src_len = shape_list(mask)
+src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
-expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
+expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE
...
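Note (illustration, not from the diff): the commit message mentions making the seq2seq models ONNX compliant, and the mask helpers above swap tf.broadcast_to for tf.tile while producing the same [bsz, 1, tgt_len, src_len] tensor; the tile form also avoids needing the batch size explicitly. A quick check with made-up shapes:

    import tensorflow as tf

    bsz, tgt_len, src_len = 2, 4, 4
    mask = tf.constant([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]])  # [bsz, src_len]

    # Old expansion (needs bsz) vs. restored tile-based expansion (does not).
    old = tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len))
    new = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    print(bool(tf.reduce_all(old == new)))  # True: identical [bsz, 1, tgt_len, src_len] tensors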
@@ -15,9 +15,10 @@
# limitations under the License.
""" TF 2.0 BERT model. """
+import math
import warnings
from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@@ -127,153 +128,51 @@ class TFBertPreTrainingLoss:
return masked_lm_loss + next_sentence_loss
class TFBertWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFBertTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFBertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config: BertConfig, **kwargs):
super().__init__(**kwargs)
-self.word_embeddings = TFBertWordEmbeddings(
-vocab_size=config.vocab_size,
-hidden_size=config.hidden_size,
-initializer_range=config.initializer_range,
-name="word_embeddings",
-)
-self.position_embeddings = TFBertPositionEmbeddings(
-max_position_embeddings=config.max_position_embeddings,
-hidden_size=config.hidden_size,
-initializer_range=config.initializer_range,
-name="position_embeddings",
-)
-self.token_type_embeddings = TFBertTokenTypeEmbeddings(
-type_vocab_size=config.type_vocab_size,
-hidden_size=config.hidden_size,
-initializer_range=config.initializer_range,
-name="token_type_embeddings",
-)
+self.vocab_size = config.vocab_size
+self.type_vocab_size = config.type_vocab_size
+self.hidden_size = config.hidden_size
+self.max_position_embeddings = config.max_position_embeddings
+self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def call(
self,
-input_ids: tf.Tensor,
-position_ids: tf.Tensor,
-token_type_ids: tf.Tensor,
-inputs_embeds: tf.Tensor,
+input_ids: tf.Tensor = None,
+position_ids: tf.Tensor = None,
+token_type_ids: tf.Tensor = None,
+inputs_embeds: tf.Tensor = None,
training: bool = False,
) -> tf.Tensor:
"""
@@ -285,18 +184,19 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
-inputs_embeds = self.word_embeddings(input_ids)
+inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
-input_shape = shape_list(inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
-position_embeds = self.position_embeddings(inputs_embeds)
-else:
-position_embeds = self.position_embeddings(position_ids)
-token_type_embeds = self.token_type_embeddings(token_type_ids)
+position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
+token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
@@ -314,31 +214,29 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
f"of attention heads ({config.num_attention_heads})"
)
+self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+self.all_head_size = self.num_attention_heads * self.attention_head_size
+self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
-self.query = tf.keras.layers.experimental.EinsumDense(
-equation="abc,cde->abde",
-output_shape=(None, config.num_attention_heads, self.attention_head_size),
-bias_axes="de",
-kernel_initializer=get_initializer(config.initializer_range),
-name="query",
+self.query = tf.keras.layers.Dense(
+units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
-self.key = tf.keras.layers.experimental.EinsumDense(
-equation="abc,cde->abde",
-output_shape=(None, config.num_attention_heads, self.attention_head_size),
-bias_axes="de",
-kernel_initializer=get_initializer(config.initializer_range),
-name="key",
+self.key = tf.keras.layers.Dense(
+units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
-self.value = tf.keras.layers.experimental.EinsumDense(
-equation="abc,cde->abde",
-output_shape=(None, config.num_attention_heads, self.attention_head_size),
-bias_axes="de",
-kernel_initializer=get_initializer(config.initializer_range),
-name="value",
+self.value = tf.keras.layers.Dense(
+units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
@@ -347,15 +245,20 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
-query_layer = self.query(inputs=hidden_states)
-key_layer = self.key(inputs=hidden_states)
-value_layer = self.value(inputs=hidden_states)
+batch_size = shape_list(hidden_states)[0]
+mixed_query_layer = self.query(inputs=hidden_states)
+mixed_key_layer = self.key(inputs=hidden_states)
+mixed_value_layer = self.value(inputs=hidden_states)
+query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
-dk = tf.cast(self.attention_head_size, dtype=query_layer.dtype)
-query_layer = tf.multiply(query_layer, tf.math.rsqrt(dk))
-attention_scores = tf.einsum("aecd,abcd->acbe", key_layer, query_layer)
+# (batch size, num_heads, seq_len_q, seq_len_k)
+attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
@@ -372,7 +275,11 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
-attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
+attention_output = tf.matmul(attention_probs, value_layer)
+attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+# (batch_size, seq_len_q, all_head_size)
+attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
@@ -382,21 +289,8 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
super().__init__(**kwargs)
-if config.hidden_size % config.num_attention_heads != 0:
-raise ValueError(
-f"The hidden size ({config.hidden_size}) is not a multiple of the number "
-f"of attention heads ({config.num_attention_heads})"
-)
-self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-self.all_head_size = config.num_attention_heads * self.attention_head_size
-self.dense = tf.keras.layers.experimental.EinsumDense(
-equation="abcd,cde->abe",
-output_shape=(None, self.all_head_size),
-bias_axes="e",
-kernel_initializer=get_initializer(config.initializer_range),
-name="dense",
+self.dense = tf.keras.layers.Dense(
+units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@@ -446,12 +340,8 @@ class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
super().__init__(**kwargs)
-self.dense = tf.keras.layers.experimental.EinsumDense(
-equation="abc,cd->abd",
-output_shape=(None, config.intermediate_size),
-bias_axes="d",
-kernel_initializer=get_initializer(config.initializer_range),
-name="dense",
+self.dense = tf.keras.layers.Dense(
+units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
@@ -470,12 +360,8 @@ class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
super().__init__(**kwargs)
-self.dense = tf.keras.layers.experimental.EinsumDense(
-equation="abc,cd->abd",
-bias_axes="d",
-output_shape=(None, config.hidden_size),
-kernel_initializer=get_initializer(config.initializer_range),
-name="dense",
+self.dense = tf.keras.layers.Dense(
+units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
@@ -698,11 +584,11 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self) -> tf.keras.layers.Layer:
-return self.embeddings.word_embeddings
+return self.embeddings
def set_input_embeddings(self, value: tf.Variable):
-self.embeddings.word_embeddings.weight = value
-self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
+self.embeddings.weight = value
+self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
@@ -1041,7 +927,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
self.bert = TFBertMainLayer(config, name="bert")
self.nsp = TFBertNSPHead(config, name="nsp___cls")
-self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings.word_embeddings, name="mlm___cls")
+self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
@@ -1165,7 +1051,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
)
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
-self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings.word_embeddings, name="mlm___cls")
+self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
@@ -1270,7 +1156,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
-self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings.word_embeddings, name="mlm___cls")
+self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
...
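Note (sketch, not the transformers implementation): the restored TFBertSelfAttention above goes back to Dense projections plus an explicit transpose_for_scores reshape and matmul-based attention scores, instead of EinsumDense layers and einsum contractions. The standalone function below mirrors that flow under simplified assumptions (no attention mask, no dropout, no head_mask, bias-free projection matrices):

    import math
    import tensorflow as tf

    def restored_attention(hidden_states, wq, wk, wv, num_heads):
        """Illustrative scaled dot-product attention in the restored style."""
        batch_size = tf.shape(hidden_states)[0]
        head_size = wq.shape[-1] // num_heads

        def split_heads(t):
            # [batch, seq, all_head_size] -> [batch, heads, seq, head_size]
            t = tf.reshape(t, (batch_size, -1, num_heads, head_size))
            return tf.transpose(t, perm=[0, 2, 1, 3])

        q, k, v = (split_heads(tf.matmul(hidden_states, w)) for w in (wq, wk, wv))
        scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(head_size)
        probs = tf.nn.softmax(scores, axis=-1)
        context = tf.transpose(tf.matmul(probs, v), perm=[0, 2, 1, 3])
        # Merge heads back to [batch, seq, all_head_size]
        return tf.reshape(context, (batch_size, -1, num_heads * head_size))

    # toy usage with random weights
    h = tf.random.normal((2, 5, 8))
    w = [tf.random.normal((8, 8)) for _ in range(3)]
    print(restored_attention(h, *w, num_heads=2).shape)  # (2, 5, 8)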
@@ -96,7 +96,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
-return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
@@ -104,10 +105,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
-bsz, src_len = shape_list(mask)
+src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
-expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
+expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE
...
@@ -94,7 +94,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
-return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
@@ -102,10 +103,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
-bsz, src_len = shape_list(mask)
+src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
-expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
+expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE
...
@@ -62,148 +62,55 @@ TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
-class TFConvBertWordEmbeddings(tf.keras.layers.Layer):
+# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape=input_shape)
def get_config(self):
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFConvBertTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape=input_shape)
def get_config(self):
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids):
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFConvBertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def get_config(self):
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFConvBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
-self.word_embeddings = TFConvBertWordEmbeddings(
-vocab_size=config.vocab_size,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="word_embeddings",
-)
-self.position_embeddings = TFConvBertPositionEmbeddings(
-max_position_embeddings=config.max_position_embeddings,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="position_embeddings",
-)
-self.token_type_embeddings = TFConvBertTokenTypeEmbeddings(
-type_vocab_size=config.type_vocab_size,
-hidden_size=config.embedding_size,
-initializer_range=config.initializer_range,
-name="token_type_embeddings",
-)
+self.vocab_size = config.vocab_size
+self.type_vocab_size = config.type_vocab_size
+self.embedding_size = config.embedding_size
+self.max_position_embeddings = config.max_position_embeddings
+self.initializer_range = config.initializer_range
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
-def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
+def build(self, input_shape: tf.TensorShape):
+with tf.name_scope("word_embeddings"):
+self.weight = self.add_weight(
+name="weight",
+shape=[self.vocab_size, self.embedding_size],
+initializer=get_initializer(initializer_range=self.initializer_range),
+)
+with tf.name_scope("token_type_embeddings"):
+self.token_type_embeddings = self.add_weight(
+name="embeddings",
+shape=[self.type_vocab_size, self.embedding_size],
+initializer=get_initializer(initializer_range=self.initializer_range),
+)
+with tf.name_scope("position_embeddings"):
+self.position_embeddings = self.add_weight(
+name="embeddings",
+shape=[self.max_position_embeddings, self.embedding_size],
+initializer=get_initializer(initializer_range=self.initializer_range),
+)
+super().build(input_shape)
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
+def call(
+self,
+input_ids: tf.Tensor = None,
+position_ids: tf.Tensor = None,
+token_type_ids: tf.Tensor = None,
+inputs_embeds: tf.Tensor = None,
+training: bool = False,
+) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
@@ -213,18 +120,19 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
-inputs_embeds = self.word_embeddings(input_ids=input_ids)
+inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
-input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
-position_embeds = self.position_embeddings(position_ids=inputs_embeds)
-else:
-position_embeds = self.position_embeddings(position_ids=position_ids)
-token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids)
+position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
+token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
@@ -296,6 +204,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
+# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
@@ -315,18 +224,27 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
conv_kernel_layer = tf.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
conv_kernel_layer = tf.nn.softmax(conv_kernel_layer, axis=1)
+paddings = tf.constant(
+[
+[
+0,
+0,
+],
+[int((self.conv_kernel_size - 1) / 2), int((self.conv_kernel_size - 1) / 2)],
+[0, 0],
+]
+)
conv_out_layer = self.conv_out_layer(hidden_states)
conv_out_layer = tf.reshape(conv_out_layer, [batch_size, -1, self.all_head_size])
+conv_out_layer = tf.pad(conv_out_layer, paddings, "CONSTANT")
-conv_out_layer = tf.reshape(
-conv_out_layer, [batch_size, shape_list(mixed_query_layer)[1], self.all_head_size, 1]
-)
-unfold_conv_out_layer = tf.image.extract_patches(
-images=conv_out_layer,
-sizes=[1, self.conv_kernel_size, 1, 1],
-strides=[1, 1, 1, 1],
-rates=[1, 1, 1, 1],
-padding="SAME",
+unfold_conv_out_layer = tf.stack(
+[
+tf.slice(conv_out_layer, [0, i, 0], [batch_size, shape_list(mixed_query_layer)[1], self.all_head_size])
+for i in range(self.conv_kernel_size)
+],
+axis=-1,
)
conv_out_layer = tf.reshape(unfold_conv_out_layer, [-1, self.attention_head_size, self.conv_kernel_size])
@@ -601,11 +519,11 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
self.config = config
def get_input_embeddings(self):
-return self.embeddings.word_embeddings
+return self.embeddings
def set_input_embeddings(self, value):
-self.embeddings.word_embeddings.weight = value
-self.embeddings.word_embeddings.vocab_size = value.shape[0]
+self.embeddings.weight = value
+self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
"""
@@ -953,9 +871,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
else:
self.activation = config.hidden_act
-self.generator_lm_head = TFConvBertMaskedLMHead(
-config, self.convbert.embeddings.word_embeddings, name="generator_lm_head"
-)
+self.generator_lm_head = TFConvBertMaskedLMHead(config, self.convbert.embeddings, name="generator_lm_head")
def get_lm_head(self):
return self.generator_lm_head
...
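Note (toy illustration, not from the PR): the ConvBERT hunk above replaces tf.image.extract_patches with an explicit pad-then-slice "unfold". Padding the sequence by (kernel_size - 1) / 2 on each side and stacking kernel_size shifted slices yields one kernel-sized window per position, matching what the SAME-padded patch extraction was computing. With made-up sizes:

    import tensorflow as tf

    batch, seq_len, width, k = 1, 4, 3, 3
    x = tf.reshape(tf.range(batch * seq_len * width, dtype=tf.float32), (batch, seq_len, width))

    pad = (k - 1) // 2
    padded = tf.pad(x, [[0, 0], [pad, pad], [0, 0]], "CONSTANT")
    # k shifted views of the padded sequence, stacked on a new trailing axis
    unfolded = tf.stack(
        [tf.slice(padded, [0, i, 0], [batch, seq_len, width]) for i in range(k)],
        axis=-1,
    )
    print(unfolded.shape)  # (1, 4, 3, 3): one k-long window per sequence position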
@@ -17,7 +17,6 @@
"""
import warnings
-from typing import Any, Dict
import tensorflow as tf
@@ -68,81 +67,6 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFDistilBertWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
class TFDistilBertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
@@ -151,23 +75,29 @@ class TFEmbeddings(tf.keras.layers.Layer):
self.vocab_size = config.vocab_size
self.dim = config.dim
self.initializer_range = config.initializer_range
+self.max_position_embeddings = config.max_position_embeddings
-self.word_embeddings = TFDistilBertWordEmbeddings(
-vocab_size=config.vocab_size,
-hidden_size=config.dim,
-initializer_range=config.initializer_range,
-name="word_embeddings",
-)
-self.position_embeddings = TFDistilBertPositionEmbeddings(
-max_position_embeddings=config.max_position_embeddings,
-hidden_size=config.dim,
-initializer_range=config.initializer_range,
-name="position_embeddings",
-)
self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.dropout)
+def build(self, input_shape: tf.TensorShape):
+with tf.name_scope("word_embeddings"):
+self.weight = self.add_weight(
+name="weight",
+shape=[self.vocab_size, self.dim],
+initializer=get_initializer(initializer_range=self.initializer_range),
+)
+with tf.name_scope("position_embeddings"):
+self.position_embeddings = self.add_weight(
+name="embeddings",
+shape=[self.max_position_embeddings, self.dim],
+initializer=get_initializer(initializer_range=self.initializer_range),
+)
+super().build(input_shape)
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
"""
Applies embedding based on inputs tensor.
@@ -178,13 +108,15 @@ class TFEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None:
-inputs_embeds = self.word_embeddings(input_ids=input_ids)
+inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+input_shape = shape_list(inputs_embeds)[:-1]
if position_ids is None:
-position_embeds = self.position_embeddings(position_ids=inputs_embeds)
-else:
-position_embeds = self.position_embeddings(position_ids=position_ids)
+position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
+position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
@@ -422,11 +354,11 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
self.transformer = TFTransformer(config, name="transformer") # Encoder
def get_input_embeddings(self):
-return self.embeddings.word_embeddings
+return self.embeddings
def set_input_embeddings(self, value):
-self.embeddings.word_embeddings.weight = value
-self.embeddings.word_embeddings.vocab_size = value.shape[0]
+self.embeddings.weight = value
+self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
@@ -716,9 +648,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
)
self.act = get_tf_activation("gelu")
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
-self.vocab_projector = TFDistilBertLMHead(
-config, self.distilbert.embeddings.word_embeddings, name="vocab_projector"
-)
+self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def get_lm_head(self):
return self.vocab_projector
...
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# limitations under the License. # limitations under the License.
""" TF Electra model. """ """ TF Electra model. """
import math
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
import tensorflow as tf import tensorflow as tf
...@@ -70,122 +71,6 @@ TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -70,122 +71,6 @@ TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFElectraWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFElectraTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
class TFElectraPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra
class TFElectraSelfAttention(tf.keras.layers.Layer): class TFElectraSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs): def __init__(self, config: ElectraConfig, **kwargs):
...@@ -197,31 +82,29 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): ...@@ -197,31 +82,29 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
f"of attention heads ({config.num_attention_heads})" f"of attention heads ({config.num_attention_heads})"
) )
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.experimental.EinsumDense( self.query = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
) )
self.key = tf.keras.layers.experimental.EinsumDense( self.key = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
) )
self.value = tf.keras.layers.experimental.EinsumDense( self.value = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
) )
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call( def call(
self, self,
hidden_states: tf.Tensor, hidden_states: tf.Tensor,
...@@ -230,15 +113,20 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): ...@@ -230,15 +113,20 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
output_attentions: bool, output_attentions: bool,
training: bool = False, training: bool = False,
) -> Tuple[tf.Tensor]: ) -> Tuple[tf.Tensor]:
query_layer = self.query(inputs=hidden_states) batch_size = shape_list(hidden_states)[0]
key_layer = self.key(inputs=hidden_states) mixed_query_layer = self.query(inputs=hidden_states)
value_layer = self.value(inputs=hidden_states) mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
dk = tf.cast(self.attention_head_size, dtype=query_layer.dtype) # (batch size, num_heads, seq_len_q, seq_len_k)
query_layer = tf.multiply(query_layer, tf.math.rsqrt(dk)) attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.einsum("aecd,abcd->acbe", key_layer, query_layer) dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in TFElectraModel call() function) # Apply the attention mask (precomputed for all layers in TFElectraModel call() function)
...@@ -255,7 +143,11 @@ class TFElectraSelfAttention(tf.keras.layers.Layer): ...@@ -255,7 +143,11 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
if head_mask is not None: if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask) attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer) attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs return outputs
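To make the restored attention path easier to follow, here is a self-contained sketch with toy tensors (sizes and layer instances are illustrative, not the model's weights): project with Dense, split into heads, take the query/key dot product, and scale by the square root of the head size, the same flow as TFElectraSelfAttention.call above.

import math
import tensorflow as tf

batch, seq_len, num_heads, head_size = 2, 4, 3, 5
hidden = num_heads * head_size

def transpose_for_scores(t):
    # (batch, seq, hidden) -> (batch, heads, seq, head_size)
    t = tf.reshape(t, (batch, -1, num_heads, head_size))
    return tf.transpose(t, perm=[0, 2, 1, 3])

hidden_states = tf.random.normal((batch, seq_len, hidden))
q = transpose_for_scores(tf.keras.layers.Dense(hidden)(hidden_states))
k = transpose_for_scores(tf.keras.layers.Dense(hidden)(hidden_states))
v = transpose_for_scores(tf.keras.layers.Dense(hidden)(hidden_states))

scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(head_size)          # (batch, heads, seq, seq)
probs = tf.nn.softmax(scores, axis=-1)
context = tf.matmul(probs, v)                                              # (batch, heads, seq, head_size)
context = tf.reshape(tf.transpose(context, perm=[0, 2, 1, 3]), (batch, seq_len, hidden))
print(context.shape)                                                       # (2, 4, 15)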
...@@ -266,21 +158,8 @@ class TFElectraSelfOutput(tf.keras.layers.Layer): ...@@ -266,21 +158,8 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs): def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0: self.dense = tf.keras.layers.Dense(
raise ValueError( units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = config.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -332,12 +211,8 @@ class TFElectraIntermediate(tf.keras.layers.Layer): ...@@ -332,12 +211,8 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs): def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -357,12 +232,8 @@ class TFElectraOutput(tf.keras.layers.Layer): ...@@ -357,12 +232,8 @@ class TFElectraOutput(tf.keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs): def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -485,35 +356,46 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): ...@@ -485,35 +356,46 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.word_embeddings = TFElectraWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.type_vocab_size = config.type_vocab_size
hidden_size=config.embedding_size, self.embedding_size = config.embedding_size
initializer_range=config.initializer_range, self.max_position_embeddings = config.max_position_embeddings
name="word_embeddings", self.initializer_range = config.initializer_range
)
self.position_embeddings = TFElectraPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.embedding_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.token_type_embeddings = TFElectraTokenTypeEmbeddings(
type_vocab_size=config.type_vocab_size,
hidden_size=config.embedding_size,
initializer_range=config.initializer_range,
name="token_type_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call( def call(
self, self,
input_ids: tf.Tensor, input_ids: tf.Tensor = None,
position_ids: tf.Tensor, position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor, token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor, inputs_embeds: tf.Tensor = None,
training: bool = False, training: bool = False,
) -> tf.Tensor: ) -> tf.Tensor:
""" """
...@@ -525,18 +407,19 @@ class TFElectraEmbeddings(tf.keras.layers.Layer): ...@@ -525,18 +407,19 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None: if token_type_ids is None:
input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None: if position_ids is None:
position_embeds = self.position_embeddings(position_ids=inputs_embeds) position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
else:
position_embeds = self.position_embeddings(position_ids=position_ids)
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
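One behavioral note on the token type lookup restored above: the removed TFElectraTokenTypeEmbeddings built the embeddings with a one-hot matmul, while the new code uses tf.gather. A toy check (illustrative values only) that the two agree:

import tensorflow as tf

type_vocab_size, hidden = 2, 4
table = tf.random.normal((type_vocab_size, hidden))
token_type_ids = tf.constant([[0, 0, 1]])

gathered = tf.gather(params=table, indices=token_type_ids)                 # (1, 3, 4)
one_hot = tf.one_hot(tf.reshape(token_type_ids, [-1]), depth=type_vocab_size)
via_matmul = tf.reshape(tf.matmul(one_hot, table), (1, 3, hidden))
print(bool(tf.reduce_all(tf.abs(gathered - via_matmul) < 1e-6)))           # True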
...@@ -605,11 +488,11 @@ class TFElectraMainLayer(tf.keras.layers.Layer): ...@@ -605,11 +488,11 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
self.config = config self.config = config
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
...@@ -1057,9 +940,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -1057,9 +940,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
else: else:
self.activation = config.hidden_act self.activation = config.hidden_act
self.generator_lm_head = TFElectraMaskedLMHead( self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
config, self.electra.embeddings.word_embeddings, name="generator_lm_head"
)
def get_lm_head(self): def get_lm_head(self):
return self.generator_lm_head return self.generator_lm_head
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -74,61 +74,29 @@ TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -74,61 +74,29 @@ TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST = [
INF = 1e6 INF = 1e6
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFFunnelWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFFunnelEmbeddings(tf.keras.layers.Layer): class TFFunnelEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.""" """Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.word_embeddings = TFFunnelWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.hidden_size = config.hidden_size
hidden_size=config.hidden_size, self.initializer_range = config.initializer_range
initializer_range=config.initializer_range,
name="word_embeddings",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout)
def build(self, input_shape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def call(self, input_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -140,7 +108,7 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer): ...@@ -140,7 +108,7 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is not None and inputs_embeds is not None) assert not (input_ids is not None and inputs_embeds is not None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(self.weight, input_ids)
final_embeddings = self.LayerNorm(inputs=inputs_embeds) final_embeddings = self.LayerNorm(inputs=inputs_embeds)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -513,13 +481,15 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): ...@@ -513,13 +481,15 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
# Shape batch_size x n_head x seq_len x 2 # Shape batch_size x n_head x seq_len x 2
token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed) token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
# Shape batch_size x n_head x seq_len x context_len # Shape batch_size x n_head x seq_len x context_len
new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len] token_type_mat = tf.tile(token_type_mat[:, None], [1, shape_list(q_head)[2], 1, 1])
token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape) # token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
# Shapes batch_size x n_head x seq_len # Shapes batch_size x n_head x seq_len
diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1) diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
# Shape batch_size x n_head x seq_len x context_len # Shape batch_size x n_head x seq_len x context_len
token_type_attn = tf.where( token_type_attn = tf.where(
token_type_mat, tf.broadcast_to(same_token_type, new_shape), tf.broadcast_to(diff_token_type, new_shape) token_type_mat,
tf.tile(same_token_type, [1, 1, 1, context_len]),
tf.tile(diff_token_type, [1, 1, 1, context_len]),
) )
if cls_mask is not None: if cls_mask is not None:
...@@ -773,11 +743,11 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer): ...@@ -773,11 +743,11 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
self.encoder = TFFunnelEncoder(config, name="encoder") self.encoder = TFFunnelEncoder(config, name="encoder")
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
...@@ -859,11 +829,11 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -859,11 +829,11 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
self.decoder = TFFunnelDecoder(config, name="decoder") self.decoder = TFFunnelDecoder(config, name="decoder")
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
...@@ -1360,7 +1330,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss) ...@@ -1360,7 +1330,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.funnel = TFFunnelMainLayer(config, name="funnel") self.funnel = TFFunnelMainLayer(config, name="funnel")
self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings.word_embeddings, name="lm_head") self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")
def get_lm_head(self): def get_lm_head(self):
return self.lm_head return self.lm_head
......
...@@ -87,17 +87,17 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i ...@@ -87,17 +87,17 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0: if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" """
bsz, src_len = shape_list(mask) src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE return (1.0 - expanded_mask) * LARGE_NEGATIVE
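The mask helpers above now build the expansion with tf.tile rather than tf.broadcast_to. A toy check (stand-in constant for LARGE_NEGATIVE, toy mask) that the tiled expansion matches the broadcast it replaces:

import tensorflow as tf

large_negative = -1e9                                                      # stand-in, not the module's actual constant
mask = tf.constant([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])                     # (bsz=2, src_len=3), zeros mark padding
tgt_len = 2
tiled = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))                # (2, 1, 2, 3)
broadcast = tf.broadcast_to(mask[:, None, None, :], (2, 1, tgt_len, 3))
print(bool(tf.reduce_all(tiled == broadcast)))                             # True
print(((1.0 - tiled) * large_negative)[0, 0].numpy())                      # padded positions become large negative biases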
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Optional, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -415,126 +415,6 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se ...@@ -415,126 +415,6 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
return attention_mask return attention_mask
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFLongformerWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFLongformerTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFLongformerPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer
class TFLongformerLMHead(tf.keras.layers.Layer): class TFLongformerLMHead(tf.keras.layers.Layer):
"""Longformer Head for masked language modeling.""" """Longformer Head for masked language modeling."""
...@@ -598,28 +478,39 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): ...@@ -598,28 +478,39 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.padding_idx = 1 self.padding_idx = 1
self.word_embeddings = TFLongformerWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.type_vocab_size = config.type_vocab_size
hidden_size=config.hidden_size, self.hidden_size = config.hidden_size
initializer_range=config.initializer_range, self.max_position_embeddings = config.max_position_embeddings
name="word_embeddings", self.initializer_range = config.initializer_range
)
self.position_embeddings = TFLongformerPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.token_type_embeddings = TFLongformerTokenTypeEmbeddings(
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="token_type_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def create_position_ids_from_input_ids(self, input_ids): def create_position_ids_from_input_ids(self, input_ids):
""" """
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
...@@ -627,36 +518,13 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): ...@@ -627,36 +518,13 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
Args: Args:
input_ids: tf.Tensor input_ids: tf.Tensor
Returns: tf.Tensor Returns: tf.Tensor
""" """
input_ids_shape = shape_list(tensor=input_ids)
# multiple choice has 3 dimensions
if len(input_ids_shape) == 3:
input_ids = tf.reshape(
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx return incremental_indices + self.padding_idx
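A small worked example (toy ids, not from the diff) of the position-id scheme kept above: non-padding tokens are numbered from padding_idx + 1 onward, and padding tokens stay at padding_idx.

import tensorflow as tf

padding_idx = 1
input_ids = tf.constant([[0, 5, 7, 1, 1]])                                 # 1 is the padding id
mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
print((incremental_indices + padding_idx).numpy())                         # [[2 3 4 1 1]]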
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
Args:
inputs_embeds: tf.Tensor
Returns: tf.Tensor
"""
batch_size, seq_length = shape_list(tensor=inputs_embeds)[:2]
position_ids = tf.range(start=self.padding_idx + 1, limit=seq_length + self.padding_idx + 1)[tf.newaxis, :]
return tf.tile(input=position_ids, multiples=(batch_size, 1))
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -667,10 +535,11 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): ...@@ -667,10 +535,11 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None: if token_type_ids is None:
input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None: if position_ids is None:
...@@ -678,10 +547,13 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer): ...@@ -678,10 +547,13 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
# Create the position ids from the input token ids. Any padded tokens remain padded. # Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds=inputs_embeds) position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
tf.newaxis, :
]
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = self.position_embeddings(position_ids=position_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids) token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -694,12 +566,8 @@ class TFLongformerIntermediate(tf.keras.layers.Layer): ...@@ -694,12 +566,8 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
def __init__(self, config: LongformerConfig, **kwargs): def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -719,12 +587,8 @@ class TFLongformerOutput(tf.keras.layers.Layer): ...@@ -719,12 +587,8 @@ class TFLongformerOutput(tf.keras.layers.Layer):
def __init__(self, config: LongformerConfig, **kwargs): def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -758,20 +622,21 @@ class TFLongformerPooler(tf.keras.layers.Layer): ...@@ -758,20 +622,21 @@ class TFLongformerPooler(tf.keras.layers.Layer):
return pooled_output return pooled_output
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
class TFLongformerSelfOutput(tf.keras.layers.Layer): class TFLongformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense( self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False): def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states return hidden_states
...@@ -1676,11 +1541,11 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1676,11 +1541,11 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
...@@ -2119,7 +1984,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -2119,7 +1984,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings.word_embeddings, name="lm_head") self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
def get_lm_head(self): def get_lm_head(self):
return self.lm_head return self.lm_head
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -177,150 +177,45 @@ class TFLxmertVisualFeatureEncoder(tf.keras.layers.Layer): ...@@ -177,150 +177,45 @@ class TFLxmertVisualFeatureEncoder(tf.keras.layers.Layer):
return output return output
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFLxmertWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFLxmertTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
class TFLxmertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFLxmertEmbeddings(tf.keras.layers.Layer): class TFLxmertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.""" """Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.word_embeddings = TFLxmertWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.type_vocab_size = config.type_vocab_size
hidden_size=config.hidden_size, self.hidden_size = config.hidden_size
initializer_range=config.initializer_range, self.max_position_embeddings = config.max_position_embeddings
name="word_embeddings", self.initializer_range = config.initializer_range
)
self.position_embeddings = TFLxmertPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.token_type_embeddings = TFLxmertTokenTypeEmbeddings(
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="token_type_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def call(self, input_ids=None, token_type_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -331,14 +226,17 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer): ...@@ -331,14 +226,17 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None: if token_type_ids is None:
input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
position_embeds = self.position_embeddings(position_ids=inputs_embeds) position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -379,6 +277,7 @@ class TFLxmertAttention(tf.keras.layers.Layer): ...@@ -379,6 +277,7 @@ class TFLxmertAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size): def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
...@@ -764,11 +663,11 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -764,11 +663,11 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
self.config = config self.config = config
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError raise NotImplementedError
...@@ -1309,7 +1208,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1309,7 +1208,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
self.lxmert = TFLxmertMainLayer(config, name="lxmert") self.lxmert = TFLxmertMainLayer(config, name="lxmert")
# Pre-training heads # Pre-training heads
self.cls = TFLxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings, name="cls") self.cls = TFLxmertPreTrainingHeads(config, self.lxmert.embeddings, name="cls")
if self.task_obj_predict: if self.task_obj_predict:
self.obj_predict_head = TFLxmertVisualObjHead(config, name="obj_predict_head") self.obj_predict_head = TFLxmertVisualObjHead(config, name="obj_predict_head")
if self.task_qa: if self.task_qa:
......
...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i ...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0: if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values ...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" """
bsz, src_len = shape_list(mask) src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE return (1.0 - expanded_mask) * LARGE_NEGATIVE
......
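The two masking helpers above trade tf.broadcast_to for tf.tile, which keeps the seq2seq models ONNX-exportable; the tensors they produce are unchanged. A self-contained sketch of the restored behaviour, assuming static shapes for brevity (LARGE_NEGATIVE mirrors the constant used in the modeling files; the helper bodies here are illustrative, not the exact library code):

    import tensorflow as tf

    LARGE_NEGATIVE = -1e8

    def make_causal_mask(bsz, tgt_len, past_key_values_length=0):
        # Additive mask: LARGE_NEGATIVE above the diagonal, 0 on and below it.
        full = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
        mask = tf.linalg.band_part(full, 0, -1) - tf.linalg.band_part(full, 0, 0)
        if past_key_values_length > 0:
            mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
        # tf.tile over the batch axis replaces tf.broadcast_to.
        return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

    def expand_mask(mask, tgt_len=None):
        # [bsz, src_len] -> [bsz, 1, tgt_len, src_len], again with tf.tile instead of tf.broadcast_to.
        src_len = mask.shape[1]
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
        return (1.0 - expanded) * LARGE_NEGATIVE

    print(make_causal_mask(bsz=1, tgt_len=3).shape)                # (1, 1, 3, 3)
    print(expand_mask(tf.constant([[1, 1, 0]]), tgt_len=3).shape)  # (1, 1, 3, 3)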
...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i ...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0: if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values ...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" """
bsz, src_len = shape_list(mask) src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE return (1.0 - expanded_mask) * LARGE_NEGATIVE
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -107,122 +107,6 @@ class TFNoNorm(tf.keras.layers.Layer): ...@@ -107,122 +107,6 @@ class TFNoNorm(tf.keras.layers.Layer):
NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm} NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm}
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFMobileBertWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFMobileBertTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
class TFMobileBertPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
class TFMobileBertEmbeddings(tf.keras.layers.Layer): class TFMobileBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.""" """Construct the embeddings from word, position and token_type embeddings."""
...@@ -231,25 +115,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -231,25 +115,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
self.trigram_input = config.trigram_input self.trigram_input = config.trigram_input
self.embedding_size = config.embedding_size self.embedding_size = config.embedding_size
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.word_embeddings = TFMobileBertWordEmbeddings( self.type_vocab_size = config.type_vocab_size
vocab_size=config.vocab_size, self.max_position_embeddings = config.max_position_embeddings
hidden_size=config.embedding_size, self.initializer_range = config.initializer_range
initializer_range=config.initializer_range,
name="word_embeddings",
)
self.position_embeddings = TFMobileBertPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.token_type_embeddings = TFMobileBertTokenTypeEmbeddings(
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="token_type_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.embedding_transformation = tf.keras.layers.Dense(config.hidden_size, name="embedding_transformation") self.embedding_transformation = tf.keras.layers.Dense(config.hidden_size, name="embedding_transformation")
...@@ -260,6 +130,30 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -260,6 +130,30 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
) )
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.embedding_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -270,10 +164,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -270,10 +164,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None: if token_type_ids is None:
input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
if self.trigram_input: if self.trigram_input:
...@@ -297,11 +192,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer): ...@@ -297,11 +192,11 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
inputs_embeds = self.embedding_transformation(inputs_embeds) inputs_embeds = self.embedding_transformation(inputs_embeds)
if position_ids is None: if position_ids is None:
position_embeds = self.position_embeddings(position_ids=inputs_embeds) position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
else:
position_embeds = self.position_embeddings(position_ids=position_ids)
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -337,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer): ...@@ -337,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size): def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
...@@ -772,11 +668,11 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -772,11 +668,11 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" """
......
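The MobileBERT embeddings now follow the same restored pattern as the other models in this commit: a single layer that creates its three lookup tables in build(), under name scopes that preserve the old checkpoint variable paths, and combines them with plain tf.gather in call(). A condensed sketch of that shape, with hard-coded sizes and without MobileBERT's trigram_input/embedding_transformation specifics (SketchEmbeddings is an illustrative name, not library code):

    import tensorflow as tf

    class SketchEmbeddings(tf.keras.layers.Layer):
        def __init__(self, vocab_size=30522, type_vocab_size=2, max_position_embeddings=512, hidden_size=128, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size
            self.type_vocab_size = type_vocab_size
            self.max_position_embeddings = max_position_embeddings
            self.hidden_size = hidden_size

        def build(self, input_shape):
            # Name scopes reproduce the variable paths of the previous nested sub-layers.
            with tf.name_scope("word_embeddings"):
                self.weight = self.add_weight(name="weight", shape=[self.vocab_size, self.hidden_size])
            with tf.name_scope("token_type_embeddings"):
                self.token_type_embeddings = self.add_weight(name="embeddings", shape=[self.type_vocab_size, self.hidden_size])
            with tf.name_scope("position_embeddings"):
                self.position_embeddings = self.add_weight(name="embeddings", shape=[self.max_position_embeddings, self.hidden_size])
            super().build(input_shape)

        def call(self, input_ids, token_type_ids=None, position_ids=None):
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
            input_shape = tf.shape(input_ids)
            if token_type_ids is None:
                token_type_ids = tf.fill(dims=input_shape, value=0)
            if position_ids is None:
                position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
            position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
            token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
            return inputs_embeds + position_embeds + token_type_embeds

    layer = SketchEmbeddings()
    print(layer(tf.constant([[4, 5, 6]])).shape)  # (1, 3, 128)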
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
import math import math
import warnings import warnings
from typing import Any, Dict
import tensorflow as tf import tensorflow as tf
...@@ -87,86 +86,6 @@ class TFMPNetPreTrainedModel(TFPreTrainedModel): ...@@ -87,86 +86,6 @@ class TFMPNetPreTrainedModel(TFPreTrainedModel):
return self.serving_output(output) return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerPositionEmbeddings
class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFMPNetEmbeddings(tf.keras.layers.Layer): class TFMPNetEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position embeddings.""" """Construct the embeddings from word, position embeddings."""
...@@ -174,22 +93,31 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer): ...@@ -174,22 +93,31 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.padding_idx = 1 self.padding_idx = 1
self.word_embeddings = TFMPNetWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.hidden_size = config.hidden_size
hidden_size=config.hidden_size, self.max_position_embeddings = config.max_position_embeddings
initializer_range=config.initializer_range, self.initializer_range = config.initializer_range
name="word_embeddings",
)
self.position_embeddings = TFMPNetPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def create_position_ids_from_input_ids(self, input_ids): def create_position_ids_from_input_ids(self, input_ids):
""" """
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
...@@ -197,36 +125,13 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer): ...@@ -197,36 +125,13 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
Args: Args:
input_ids: tf.Tensor input_ids: tf.Tensor
Returns: tf.Tensor Returns: tf.Tensor
""" """
input_ids_shape = shape_list(tensor=input_ids)
# multiple choice has 3 dimensions
if len(input_ids_shape) == 3:
input_ids = tf.reshape(
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx return incremental_indices + self.padding_idx
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
Args:
inputs_embeds: tf.Tensor
Returns: tf.Tensor
"""
batch_size, seq_length = shape_list(tensor=inputs_embeds)[:2]
position_ids = tf.range(start=self.padding_idx + 1, limit=seq_length + self.padding_idx + 1)[tf.newaxis, :]
return tf.tile(input=position_ids, multiples=(batch_size, 1))
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -237,16 +142,21 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer): ...@@ -237,16 +142,21 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if position_ids is None: if position_ids is None:
if input_ids is not None: if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded. # Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds=inputs_embeds) position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
tf.newaxis, :
]
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = self.position_embeddings(position_ids=position_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -281,58 +191,55 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer): ...@@ -281,58 +191,55 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
if config.hidden_size % config.num_attention_heads != 0: if config.hidden_size % config.num_attention_heads != 0:
raise ValueError( raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number " "The hidden size (%d) is not a multiple of the number of attention "
f"of attention heads ({config.num_attention_heads})" "heads (%d)" % (config.hidden_size, config.num_attention_heads)
) )
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.q = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde", self.q = tf.keras.layers.Dense(
output_shape=(None, config.num_attention_heads, self.attention_head_size), self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q"
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="q",
) )
self.k = tf.keras.layers.experimental.EinsumDense( self.k = tf.keras.layers.Dense(
equation="abc,cde->abde", self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="k",
) )
self.v = tf.keras.layers.experimental.EinsumDense( self.v = tf.keras.layers.Dense(
equation="abc,cde->abde", self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="v",
) )
self.o = tf.keras.layers.experimental.EinsumDense( self.o = tf.keras.layers.Dense(
equation="abcd,cde->abe", config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="o",
) )
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False): def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False):
batch_size = shape_list(hidden_states)[0]
q = self.q(hidden_states) q = self.q(hidden_states)
k = self.k(hidden_states) k = self.k(hidden_states)
v = self.v(hidden_states) v = self.v(hidden_states)
dk = tf.cast(self.attention_head_size, dtype=q.dtype) q = self.transpose_for_scores(q, batch_size)
q = tf.multiply(q, y=tf.math.rsqrt(dk)) k = self.transpose_for_scores(k, batch_size)
attention_scores = tf.einsum("aecd,abcd->acbe", k, q) v = self.transpose_for_scores(v, batch_size)
attention_scores = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(shape_list(k)[-1], attention_scores.dtype)
attention_scores = attention_scores / tf.math.sqrt(dk)
# Apply relative position embedding (precomputed in MPNetEncoder) if provided. # Apply relative position embedding (precomputed in MPNetEncoder) if provided.
if position_bias is not None: if position_bias is not None:
attention_scores += position_bias attention_scores += position_bias
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFMPNetModel call() function)
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
attention_probs = tf.nn.softmax(attention_scores, axis=-1) attention_probs = tf.nn.softmax(attention_scores, axis=-1)
...@@ -342,7 +249,9 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer): ...@@ -342,7 +249,9 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
if head_mask is not None: if head_mask is not None:
attention_probs = attention_probs * head_mask attention_probs = attention_probs * head_mask
c = tf.einsum("acbe,aecd->abcd", attention_probs, v) c = tf.matmul(attention_probs, v)
c = tf.transpose(c, perm=[0, 2, 1, 3])
c = tf.reshape(c, (batch_size, -1, self.all_head_size))
o = self.o(c) o = self.o(c)
outputs = (o, attention_probs) if output_attentions else (o,) outputs = (o, attention_probs) if output_attentions else (o,)
...@@ -374,12 +283,8 @@ class TFMPNetIntermediate(tf.keras.layers.Layer): ...@@ -374,12 +283,8 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
def __init__(self, config: MPNetConfig, **kwargs): def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -399,12 +304,8 @@ class TFMPNetOutput(tf.keras.layers.Layer): ...@@ -399,12 +304,8 @@ class TFMPNetOutput(tf.keras.layers.Layer):
def __init__(self, config: MPNetConfig, **kwargs): def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -565,12 +466,12 @@ class TFMPNetMainLayer(tf.keras.layers.Layer): ...@@ -565,12 +466,12 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self) -> tf.keras.layers.Layer: def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.word_embeddings return self.embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value: tf.Variable): def set_input_embeddings(self, value: tf.Variable):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
...@@ -894,7 +795,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss): ...@@ -894,7 +795,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.mpnet = TFMPNetMainLayer(config, name="mpnet") self.mpnet = TFMPNetMainLayer(config, name="mpnet")
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings.word_embeddings, name="lm_head") self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")
def get_lm_head(self): def get_lm_head(self):
return self.lm_head return self.lm_head
......
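MPNet's attention (and RoBERTa's below) moves from EinsumDense projections back to ordinary Dense layers with an explicit transpose_for_scores reshape, so the whole computation reduces to two tf.matmul calls. A minimal standalone sketch of that flow with made-up sizes (none of these names come from the library):

    import tensorflow as tf

    num_heads, head_size = 4, 16
    all_head_size = num_heads * head_size

    def transpose_for_scores(x, batch_size):
        # [batch, seq, all_head_size] -> [batch, num_heads, seq, head_size]
        x = tf.reshape(x, (batch_size, -1, num_heads, head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    q_proj = tf.keras.layers.Dense(all_head_size)
    k_proj = tf.keras.layers.Dense(all_head_size)
    v_proj = tf.keras.layers.Dense(all_head_size)
    o_proj = tf.keras.layers.Dense(all_head_size)

    hidden_states = tf.random.normal((2, 7, all_head_size))
    batch_size = hidden_states.shape[0]

    q = transpose_for_scores(q_proj(hidden_states), batch_size)
    k = transpose_for_scores(k_proj(hidden_states), batch_size)
    v = transpose_for_scores(v_proj(hidden_states), batch_size)

    # Restored dot-product attention: the scores are scaled after the matmul.
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(head_size, tf.float32))
    probs = tf.nn.softmax(scores, axis=-1)

    context = tf.matmul(probs, v)                       # [batch, num_heads, seq, head_size]
    context = tf.transpose(context, perm=[0, 2, 1, 3])  # [batch, seq, num_heads, head_size]
    context = tf.reshape(context, (batch_size, -1, all_head_size))
    output = o_proj(context)
    print(output.shape)  # (2, 7, 64)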
...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i ...@@ -95,7 +95,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0: if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values ...@@ -103,10 +104,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" """
bsz, src_len = shape_list(mask) src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE return (1.0 - expanded_mask) * LARGE_NEGATIVE
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 RoBERTa model. """ """ TF 2.0 RoBERTa model. """
import math
import warnings import warnings
from typing import Any, Dict, Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -68,127 +69,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -68,127 +69,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
class TFRobertaWordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
class TFRobertaTokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerPositionEmbeddings
class TFRobertaPositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
class TFRobertaEmbeddings(tf.keras.layers.Layer): class TFRobertaEmbeddings(tf.keras.layers.Layer):
""" """
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
...@@ -198,28 +78,39 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): ...@@ -198,28 +78,39 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.padding_idx = 1 self.padding_idx = 1
self.word_embeddings = TFRobertaWordEmbeddings( self.vocab_size = config.vocab_size
vocab_size=config.vocab_size, self.type_vocab_size = config.type_vocab_size
hidden_size=config.hidden_size, self.hidden_size = config.hidden_size
initializer_range=config.initializer_range, self.max_position_embeddings = config.max_position_embeddings
name="word_embeddings", self.initializer_range = config.initializer_range
)
self.position_embeddings = TFRobertaPositionEmbeddings(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="position_embeddings",
)
self.token_type_embeddings = TFRobertaTokenTypeEmbeddings(
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
initializer_range=config.initializer_range,
name="token_type_embeddings",
)
self.embeddings_sum = tf.keras.layers.Add() self.embeddings_sum = tf.keras.layers.Add()
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
def create_position_ids_from_input_ids(self, input_ids): def create_position_ids_from_input_ids(self, input_ids):
""" """
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
...@@ -227,36 +118,13 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): ...@@ -227,36 +118,13 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
Args: Args:
input_ids: tf.Tensor input_ids: tf.Tensor
Returns: tf.Tensor Returns: tf.Tensor
""" """
input_ids_shape = shape_list(tensor=input_ids)
# multiple choice has 3 dimensions
if len(input_ids_shape) == 3:
input_ids = tf.reshape(
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx return incremental_indices + self.padding_idx
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
Args:
inputs_embeds: tf.Tensor
Returns: tf.Tensor
"""
batch_size, seq_length = shape_list(tensor=inputs_embeds)[:2]
position_ids = tf.range(start=self.padding_idx + 1, limit=seq_length + self.padding_idx + 1)[tf.newaxis, :]
return tf.tile(input=position_ids, multiples=(batch_size, 1))
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False): def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
""" """
Applies embedding based on inputs tensor. Applies embedding based on inputs tensor.
...@@ -267,10 +135,11 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): ...@@ -267,10 +135,11 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
assert not (input_ids is None and inputs_embeds is None) assert not (input_ids is None and inputs_embeds is None)
if input_ids is not None: if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None: if token_type_ids is None:
input_shape = shape_list(tensor=inputs_embeds)[:-1]
token_type_ids = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None: if position_ids is None:
...@@ -278,10 +147,13 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer): ...@@ -278,10 +147,13 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
# Create the position ids from the input token ids. Any padded tokens remain padded. # Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids) position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds=inputs_embeds) position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
tf.newaxis, :
]
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
position_embeds = self.position_embeddings(position_ids=position_ids) position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids) token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds]) final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
final_embeddings = self.LayerNorm(inputs=final_embeddings) final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training) final_embeddings = self.dropout(inputs=final_embeddings, training=training)
...@@ -321,31 +193,29 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ...@@ -321,31 +193,29 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
f"of attention heads ({config.num_attention_heads})" f"of attention heads ({config.num_attention_heads})"
) )
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.experimental.EinsumDense( self.query = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="query",
) )
self.key = tf.keras.layers.experimental.EinsumDense( self.key = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="key",
) )
self.value = tf.keras.layers.experimental.EinsumDense( self.value = tf.keras.layers.Dense(
equation="abc,cde->abde", units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(config.initializer_range),
name="value",
) )
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call( def call(
self, self,
hidden_states: tf.Tensor, hidden_states: tf.Tensor,
...@@ -354,15 +224,20 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ...@@ -354,15 +224,20 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
output_attentions: bool, output_attentions: bool,
training: bool = False, training: bool = False,
) -> Tuple[tf.Tensor]: ) -> Tuple[tf.Tensor]:
query_layer = self.query(inputs=hidden_states) batch_size = shape_list(hidden_states)[0]
key_layer = self.key(inputs=hidden_states) mixed_query_layer = self.query(inputs=hidden_states)
value_layer = self.value(inputs=hidden_states) mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
dk = tf.cast(self.attention_head_size, dtype=query_layer.dtype) # (batch size, num_heads, seq_len_q, seq_len_k)
query_layer = tf.multiply(query_layer, tf.math.rsqrt(dk)) attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.einsum("aecd,abcd->acbe", key_layer, query_layer) dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFRobertaModel call() function) # Apply the attention mask is (precomputed for all layers in TFRobertaModel call() function)
...@@ -379,7 +254,11 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer): ...@@ -379,7 +254,11 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
if head_mask is not None: if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask) attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer) attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs return outputs
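The restored RoBERTa attention divides the raw scores by sqrt(attention_head_size) after the query-key matmul, whereas the reverted version rescaled the query with tf.math.rsqrt before an einsum. The two are numerically equivalent, as a quick check illustrates (shapes are arbitrary):

    import tensorflow as tf

    q = tf.random.normal((2, 4, 7, 16))  # [batch, heads, seq, head_size]
    k = tf.random.normal((2, 4, 7, 16))
    dk = tf.cast(16, tf.float32)

    scores_query_scaled = tf.matmul(q * tf.math.rsqrt(dk), k, transpose_b=True)  # previous variant
    scores_post_scaled = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)    # restored variant
    print(bool(tf.reduce_all(tf.abs(scores_query_scaled - scores_post_scaled) < 1e-4)))  # True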
...@@ -390,21 +269,8 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer): ...@@ -390,21 +269,8 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: RobertaConfig, **kwargs): def __init__(self, config: RobertaConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0: self.dense = tf.keras.layers.Dense(
raise ValueError( units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = config.num_attention_heads * self.attention_head_size
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -456,12 +322,8 @@ class TFRobertaIntermediate(tf.keras.layers.Layer): ...@@ -456,12 +322,8 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
def __init__(self, config: RobertaConfig, **kwargs): def __init__(self, config: RobertaConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -481,12 +343,8 @@ class TFRobertaOutput(tf.keras.layers.Layer): ...@@ -481,12 +343,8 @@ class TFRobertaOutput(tf.keras.layers.Layer):
def __init__(self, config: RobertaConfig, **kwargs): def __init__(self, config: RobertaConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense( self.dense = tf.keras.layers.Dense(
equation="abc,cd->abd", units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
) )
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
...@@ -601,12 +459,12 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): ...@@ -601,12 +459,12 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self) -> tf.keras.layers.Layer: def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.word_embeddings return self.embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value: tf.Variable): def set_input_embeddings(self, value: tf.Variable):
self.embeddings.word_embeddings.weight = value self.embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0] self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
...@@ -972,7 +830,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos ...@@ -972,7 +830,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings.word_embeddings, name="lm_head") self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_lm_head(self): def get_lm_head(self):
return self.lm_head return self.lm_head
......
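RoBERTa keeps (and MPNet above copies) the cumsum trick for turning input ids into position ids: padded slots stay at padding_idx while real tokens count up from padding_idx + 1. A quick standalone check of that behaviour:

    import tensorflow as tf

    padding_idx = 1
    input_ids = tf.constant([[5, 8, 9, 1, 1],
                             [5, 1, 1, 1, 1]])

    mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)
    incremental_indices = tf.math.cumsum(mask, axis=1) * mask
    position_ids = incremental_indices + padding_idx
    print(position_ids.numpy())
    # [[2 3 4 1 1]
    #  [2 1 1 1 1]]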
...@@ -866,7 +866,8 @@ class TFModelTesterMixin: ...@@ -866,7 +866,8 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) inputs = copy.deepcopy(inputs_dict)
if not self.is_encoder_decoder: if not self.is_encoder_decoder:
input_ids = inputs["input_ids"] input_ids = inputs["input_ids"]
del inputs["input_ids"] del inputs["input_ids"]
...@@ -882,6 +883,8 @@ class TFModelTesterMixin: ...@@ -882,6 +883,8 @@ class TFModelTesterMixin:
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids) inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids) inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
inputs = self._prepare_for_class(inputs, model_class)
model(inputs) model(inputs)
def test_graph_mode_with_inputs_embeds(self): def test_graph_mode_with_inputs_embeds(self):
...@@ -890,7 +893,8 @@ class TFModelTesterMixin: ...@@ -890,7 +893,8 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) inputs = copy.deepcopy(inputs_dict)
if not self.is_encoder_decoder: if not self.is_encoder_decoder:
input_ids = inputs["input_ids"] input_ids = inputs["input_ids"]
del inputs["input_ids"] del inputs["input_ids"]
...@@ -906,6 +910,8 @@ class TFModelTesterMixin: ...@@ -906,6 +910,8 @@ class TFModelTesterMixin:
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids) inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids) inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
inputs = self._prepare_for_class(inputs, model_class)
@tf.function @tf.function
def run_in_graph_mode(): def run_in_graph_mode():
return model(inputs) return model(inputs)
......