Unverified Commit 4adbdce5 authored by Julien Plu, committed by GitHub

Clean TF Bert (#9788)

* Start cleaning BERT

* Clean BERT and all models that depend on it

* Fix attribute name

* Apply style

* Apply Sylvain's comments

* Apply Lysandre's comments

* remove unused import
parent f0329ea5
src/transformers/modeling_tf_utils.py
......@@ -46,6 +46,10 @@ from .utils import logging
logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()
TFModelInputType = Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
]
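For readers skimming the diff: the new `TFModelInputType` alias names every container the first positional argument of a TF model may arrive in. A minimal sketch of what it covers (the `describe_inputs` helper is hypothetical, not part of the commit):

```python
from typing import Dict, List, Union

import numpy as np
import tensorflow as tf

TFModelInputType = Union[
    List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
]

def describe_inputs(inputs: TFModelInputType) -> str:
    # Dicts map argument names to tensors; lists carry positional arguments.
    if isinstance(inputs, dict):
        return f"keyword inputs: {sorted(inputs)}"
    if isinstance(inputs, list):
        return f"{len(inputs)} positional inputs"
    return f"single tensor of shape {inputs.shape}"

print(describe_inputs({"input_ids": tf.constant([[7, 8, 9]])}))  # keyword inputs: ['input_ids']
```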
class TFModelUtilsMixin:
"""
src/transformers/models/albert/modeling_tf_albert.py
......@@ -17,7 +17,7 @@
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
......@@ -82,16 +82,16 @@ class TFAlbertWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -101,14 +101,14 @@ class TFAlbertWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
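The `get_config` pattern above merges the layer's own constructor arguments into the base Keras config so the layer can be rebuilt from it. A toy round-trip, assuming nothing beyond standard Keras (`WordEmbeddings` here is a stand-in, not the real class):

```python
import tensorflow as tf

class WordEmbeddings(tf.keras.layers.Layer):
    def __init__(self, vocab_size: int, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def get_config(self):
        config = {"vocab_size": self.vocab_size, "hidden_size": self.hidden_size}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

# Keras re-creates the layer from its serialized config.
clone = WordEmbeddings.from_config(WordEmbeddings(vocab_size=10, hidden_size=4).get_config())
assert (clone.vocab_size, clone.hidden_size) == (10, 4)
```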
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
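Concretely, the lookup flattens the ids, gathers rows of the weight matrix, then restores the leading dimensions. A shape walk-through with toy sizes (all values assumed for illustration):

```python
import tensorflow as tf

vocab_size, hidden_size = 10, 4
weight = tf.random.normal([vocab_size, hidden_size])
input_ids = tf.constant([[1, 2, 3], [4, 5, 6]])                # (batch=2, seq=3)

flat_input_ids = tf.reshape(input_ids, shape=[-1])             # (6,)
embeddings = tf.gather(params=weight, indices=flat_input_ids)  # (6, 4)
embeddings = tf.reshape(embeddings, tf.concat([tf.shape(input_ids), [hidden_size]], axis=0))
print(embeddings.shape)                                        # (2, 3, 4)
```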
......@@ -122,16 +122,16 @@ class TFAlbertTokenTypeEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
......@@ -141,15 +141,15 @@ class TFAlbertTokenTypeEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
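The token-type lookup goes through a one-hot matmul rather than a gather; both produce the same embeddings, the matmul form just lowers better on some accelerators. A quick equivalence check with assumed toy sizes:

```python
import tensorflow as tf

type_vocab_size, hidden_size = 2, 4
table = tf.random.normal([type_vocab_size, hidden_size])
token_type_ids = tf.constant([[0, 0, 1]])

flat = tf.reshape(token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat, depth=type_vocab_size, dtype=table.dtype)
via_matmul = tf.matmul(a=one_hot_data, b=table)     # (3, 4)
via_gather = tf.gather(params=table, indices=flat)  # (3, 4)
tf.debugging.assert_near(via_matmul, via_gather)    # identical rows, different kernel
```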
......@@ -163,16 +163,16 @@ class TFAlbertPositionEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
......@@ -182,8 +182,8 @@ class TFAlbertPositionEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
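Note that this position layer never reads ids: it slices the first `seq_length` rows of the table and broadcasts them over the batch. A sketch assuming the argument is a rank-3 `(batch, seq, hidden)` tensor used only for its shape:

```python
import tensorflow as tf

max_position_embeddings, hidden_size = 512, 4
table = tf.random.normal([max_position_embeddings, hidden_size])

reference = tf.zeros([2, 3, hidden_size])                   # shape source only
position_embeddings = table[: tf.shape(reference)[1], :]    # (3, 4)
out = tf.broadcast_to(position_embeddings, tf.shape(reference))
print(out.shape)                                            # (2, 3, 4)
```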
......@@ -218,7 +218,14 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
def call(
self,
input_ids: tf.Tensor,
position_ids: tf.Tensor,
token_type_ids: tf.Tensor,
inputs_embeds: tf.Tensor,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
......@@ -879,7 +886,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
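All of these `serving_output` overrides do the same thing: a SavedModel signature cannot return a Python tuple of per-layer tensors, so the `hidden_states`/`attentions` tuples are stacked into single tensors (or dropped when the config disables them). A minimal sketch of the conversion:

```python
import tensorflow as tf

# One (batch, seq, hidden) tensor per layer, as the model returns them.
hidden_states = tuple(tf.zeros([2, 5, 8]) for _ in range(3))

hs = tf.convert_to_tensor(hidden_states)  # stacked: (num_layers, batch, seq, hidden)
print(hs.shape)                           # (3, 2, 5, 8)
```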
......@@ -1105,7 +1112,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1208,7 +1215,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1310,7 +1317,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1425,7 +1432,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1572,13 +1579,14 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
}
]
)
def serving(self, inputs):
output = self.call(inputs)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
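The multiple-choice `serving` overrides exist because these heads take rank-3 `(batch, num_choices, seq)` inputs, so they need their own `tf.function` input signature. A heavily simplified toy module with the same shape contract (names, shapes, and the export path are assumptions, not the library's code):

```python
import tensorflow as tf

class ToyMultipleChoice(tf.Module):
    @tf.function(
        input_signature=[
            {"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids")}
        ]
    )
    def serving(self, inputs):
        # A real model forwards to call(input_ids=inputs) and post-processes
        # the output dataclass in serving_output.
        return {"logits": tf.reduce_sum(tf.cast(inputs["input_ids"], tf.float32), axis=-1)}

toy = ToyMultipleChoice()
tf.saved_model.save(toy, "/tmp/toy_mc", signatures=toy.serving)
```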
src/transformers/models/ctrl/modeling_tf_ctrl.py
......@@ -919,7 +919,7 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/distilbert/modeling_tf_distilbert.py
......@@ -17,6 +17,7 @@
"""
import warnings
from typing import Any, Dict
import tensorflow as tf
......@@ -76,16 +77,16 @@ class TFDistilBertWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -95,14 +96,14 @@ class TFDistilBertWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -116,16 +117,16 @@ class TFDistilBertPositionEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
......@@ -135,8 +136,8 @@ class TFDistilBertPositionEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
......@@ -796,7 +797,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -897,7 +898,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -988,7 +989,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1131,7 +1132,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1238,7 +1239,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/funnel/modeling_tf_funnel.py
......@@ -16,7 +16,7 @@
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
......@@ -83,16 +83,16 @@ class TFFunnelWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -102,14 +102,14 @@ class TFFunnelWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -1436,7 +1436,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1526,7 +1526,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1656,13 +1656,13 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
}
]
)
def serving(self, inputs):
output = self.call(inputs)
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)
return self.serving_output(output)
return self.serving_output(output=output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1755,7 +1755,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1860,7 +1860,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/longformer/modeling_tf_longformer.py
......@@ -16,7 +16,7 @@
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
......@@ -424,16 +424,16 @@ class TFLongformerWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -443,14 +443,14 @@ class TFLongformerWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -464,16 +464,16 @@ class TFLongformerTokenTypeEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
......@@ -483,15 +483,15 @@ class TFLongformerTokenTypeEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -508,7 +508,7 @@ class TFLongformerPositionEmbeddings(tf.keras.layers.Layer):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
......@@ -527,10 +527,10 @@ class TFLongformerPositionEmbeddings(tf.keras.layers.Layer):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=position_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=position_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -638,8 +638,8 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(x=tf.math.not_equal(x=input_ids, y=self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(x=mask, axis=1) * mask
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx
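This helper gives real tokens positions counted upward from `padding_idx + 1` while padding positions keep `padding_idx`, matching RoBERTa-style position ids. Worked numbers (values assumed):

```python
import tensorflow as tf

padding_idx = 1
input_ids = tf.constant([[0, 5, 6, 1, 1]])   # two trailing pads

mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
position_ids = incremental_indices + padding_idx
print(position_ids.numpy())                  # [[2 3 4 1 1]] — pads stay at padding_idx
```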
......@@ -689,34 +689,34 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer
class TFLongformerIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
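`EinsumDense` with equation `"abc,cd->abd"` is just a Dense layer applied over the last axis of a `(batch, seq, hidden)` tensor; spelling it as an einsum avoids an explicit reshape. A sketch with assumed sizes, using the same experimental API path as the code above:

```python
import tensorflow as tf

intermediate = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cd->abd",
    output_shape=(None, 16),   # (seq, intermediate_size); None keeps seq dynamic
    bias_axes="d",
)
hidden_states = tf.zeros([2, 5, 8])       # (batch, seq, hidden)
print(intermediate(hidden_states).shape)  # (2, 5, 16)
```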
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer
class TFLongformerOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
......@@ -729,7 +729,7 @@ class TFLongformerOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
......@@ -737,23 +737,23 @@ class TFLongformerOutput(tf.keras.layers.Layer):
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer
class TFLongformerPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: LongformerConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
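The pooler is the usual BERT recipe: take the hidden state at position 0 (the [CLS] token) and pass it through a tanh-activated Dense. An equivalent standalone sketch (sizes assumed):

```python
import tensorflow as tf

hidden_states = tf.random.normal([2, 5, 8])    # (batch, seq, hidden)
first_token_tensor = hidden_states[:, 0]       # (batch, hidden)
dense = tf.keras.layers.Dense(units=8, activation="tanh")
pooled_output = dense(inputs=first_token_tensor)
print(pooled_output.shape)                     # (2, 8)
```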
src/transformers/models/lxmert/modeling_tf_lxmert.py
......@@ -18,7 +18,7 @@
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
......@@ -186,16 +186,16 @@ class TFLxmertWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -205,14 +205,14 @@ class TFLxmertWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -226,16 +226,16 @@ class TFLxmertTokenTypeEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
......@@ -245,15 +245,15 @@ class TFLxmertTokenTypeEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -267,16 +267,16 @@ class TFLxmertPositionEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
......@@ -286,8 +286,8 @@ class TFLxmertPositionEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
......@@ -1132,11 +1132,13 @@ class TFLxmertPooler(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Lxmert
class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: LxmertConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
......@@ -1146,17 +1148,17 @@ class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.LayerNorm(inputs=hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Lxmert
class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
def __init__(self, config: LxmertConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
......@@ -1168,28 +1170,28 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
def set_output_embeddings(self, value):
def set_output_embeddings(self, value: tf.Variable):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
def get_bias(self) -> Dict[str, tf.Variable]:
return {"bias": self.bias}
def set_bias(self, value):
def set_bias(self, value: tf.Variable):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.transform(hidden_states=hidden_states)
seq_length = shape_list(tensor=hidden_states)[1]
seq_length = shape_list(hidden_states)[1]
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.vocab_size])
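The LM head ties its projection to the input embedding matrix: hidden states are flattened, multiplied by the transposed embedding table, and reshaped back into per-token vocabulary logits. A shape walk-through with toy sizes (assumed):

```python
import tensorflow as tf

vocab_size, hidden_size, seq_length = 10, 4, 3
embedding_weight = tf.random.normal([vocab_size, hidden_size])    # shared with input embeddings
hidden_states = tf.random.normal([2, seq_length, hidden_size])

flat = tf.reshape(hidden_states, shape=[-1, hidden_size])         # (6, 4)
logits = tf.matmul(a=flat, b=embedding_weight, transpose_b=True)  # (6, 10)
logits = tf.reshape(logits, shape=[-1, seq_length, vocab_size])   # (2, 3, 10)
print(logits.shape)
```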
......@@ -1200,13 +1202,13 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->Lxmert
class TFLxmertMLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
def __init__(self, config: LxmertConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
super().__init__(**kwargs)
self.predictions = TFLxmertLMPredictionHead(config, input_embeddings, name="predictions")
def call(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
prediction_scores = self.predictions(hidden_states=sequence_output)
return prediction_scores
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
......@@ -17,7 +17,7 @@
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import tensorflow as tf
......@@ -116,16 +116,16 @@ class TFMobileBertWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -135,14 +135,14 @@ class TFMobileBertWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -156,16 +156,16 @@ class TFMobileBertTokenTypeEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
......@@ -175,15 +175,15 @@ class TFMobileBertTokenTypeEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -197,16 +197,16 @@ class TFMobileBertPositionEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
......@@ -216,8 +216,8 @@ class TFMobileBertPositionEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
......@@ -1085,7 +1085,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1299,7 +1299,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1413,7 +1413,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForNextSentencePrediction.serving_output
def serving_output(self, output):
def serving_output(self, output: TFNextSentencePredictorOutput) -> TFNextSentencePredictorOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1522,7 +1522,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1643,7 +1643,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1796,13 +1796,14 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
}
]
)
def serving(self, inputs):
output = self.call(inputs)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1911,7 +1912,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/mpnet/modeling_tf_mpnet.py
......@@ -18,6 +18,7 @@
import math
import warnings
from typing import Any, Dict
import tensorflow as tf
......@@ -95,16 +96,16 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
......@@ -114,14 +115,14 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -139,7 +140,7 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
......@@ -158,10 +159,10 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=position_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=position_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
......@@ -207,8 +208,8 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(x=tf.math.not_equal(x=input_ids, y=self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(x=mask, axis=1) * mask
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx
......@@ -253,23 +254,23 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->MPNet
class TFMPNetPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
......@@ -291,28 +292,28 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="q",
)
self.k = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="k",
)
self.v = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="v",
)
self.o = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="o",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
......@@ -322,8 +323,8 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
k = self.k(hidden_states)
v = self.v(hidden_states)
dk = tf.cast(x=self.attention_head_size, dtype=q.dtype)
q = tf.multiply(x=q, y=tf.math.rsqrt(x=dk))
dk = tf.cast(self.attention_head_size, dtype=q.dtype)
q = tf.multiply(q, y=tf.math.rsqrt(dk))
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
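The scaling change keeps the same math: queries are multiplied by 1/sqrt(head_size) before the QK einsum, so the scores come out as scaled dot products. A shape check under assumed sizes:

```python
import tensorflow as tf

batch, seq, heads, attention_head_size = 2, 5, 8, 64
q = tf.random.normal([batch, seq, heads, attention_head_size])
k = tf.random.normal([batch, seq, heads, attention_head_size])

dk = tf.cast(attention_head_size, dtype=q.dtype)
q = tf.multiply(q, tf.math.rsqrt(dk))                  # scale queries by 1/sqrt(dk)
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)  # (batch, heads, q_seq, k_seq)
print(attention_scores.shape)                          # (2, 8, 5, 5)
```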
......@@ -368,34 +369,34 @@ class TFMPNetAttention(tf.keras.layers.Layer):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet
class TFMPNetIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet
class TFMPNetOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
......@@ -408,7 +409,7 @@ class TFMPNetOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
......@@ -563,11 +564,11 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
self.embeddings = TFMPNetEmbeddings(config, name="embeddings")
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self):
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.word_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
def set_input_embeddings(self, value: tf.Variable):
self.embeddings.word_embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
......@@ -820,7 +821,7 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -973,7 +974,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1095,7 +1096,7 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassif
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1233,7 +1234,7 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1333,7 +1334,7 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1446,7 +1447,7 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/openai/modeling_tf_openai.py
......@@ -663,7 +663,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -965,7 +965,7 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
src/transformers/models/xlm/modeling_tf_xlm.py
......@@ -19,7 +19,7 @@
import itertools
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
......@@ -1019,7 +1019,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1180,13 +1180,14 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
}
]
)
def serving(self, inputs):
output = self.call(inputs)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1294,7 +1295,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......@@ -1413,7 +1414,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None