Unverified commit 31b0560a, authored by Julien Plu, committed by GitHub

Add AMP for Albert (#10141)

parent 6fc940ed
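Context for the change: AMP here means TensorFlow's automatic mixed precision (the Keras "mixed_float16" policy). The hunks below make the TF models dtype-safe under that policy, mainly by letting constants follow the dtype of the tensors they interact with. A minimal sketch of running one of these models under mixed precision, assuming a recent TF 2.x API and the standard "albert-base-v2" checkpoint (not part of the diff):

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertModel

    # Compute in float16, keep variables in float32 (Keras mixed precision policy).
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = TFAlbertModel.from_pretrained("albert-base-v2")

    inputs = tokenizer("Mixed precision smoke test.", return_tensors="tf")
    outputs = model(inputs)
    print(outputs.last_hidden_state.dtype)  # typically float16 under this policy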
@@ -148,21 +148,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
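The repeated edit in these build() hunks is cosmetic: initializer_range is now passed positionally to get_initializer. In transformers that helper just wraps a truncated-normal initializer, so behaviour is unchanged; a rough sketch of such a helper (illustrative, not the library's exact source):

    import tensorflow as tf

    def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
        # Truncated-normal weight initializer with the configured standard deviation.
        return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)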
@@ -253,8 +253,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

-        # Take the dot product between "query" and "key" to get the raw
-        # attention scores.
+        # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
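The unchanged dk = tf.cast(...) context line above is the pattern the AMP work depends on: the scaling constant is cast to the dtype of the attention scores rather than being a hard float32. A small sketch of why that matters under mixed_float16 (the shapes and head size here are made up for illustration):

    import math

    import tensorflow as tf

    attention_scores = tf.random.uniform((1, 12, 8, 8), dtype=tf.float16)  # float16 activations under AMP
    sqrt_att_head_size = math.sqrt(64.0)  # plain Python float

    # Casting the constant to the tensor's dtype avoids mixing float16 and float32 operands.
    dk = tf.cast(sqrt_att_head_size, dtype=attention_scores.dtype)
    attention_scores = tf.divide(attention_scores, dk)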
@@ -1009,7 +1008,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
             total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

         if not inputs["return_dict"]:
-            return (prediction_scores, seq_relationship_score) + outputs[2:]
+            output = (prediction_scores, seq_relationship_score) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output

         return TFBertForPreTrainingOutput(
             loss=total_loss,
@@ -1598,7 +1598,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
             }
         ]
     )
-    def serving(self, inputs: Dict[str, tf.Tensor]):
+    def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
         output = self.call(input_ids=inputs)

         return self.serving_output(output)
...
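The only functional bit in the TFBertForMultipleChoice.serving hunk above is the return annotation; serving is the tf.function used when exporting the model as a SavedModel. A hedged example of exporting that signature (model name and output path are assumptions, not part of this diff):

    import tensorflow as tf
    from transformers import TFBertForMultipleChoice

    model = TFBertForMultipleChoice.from_pretrained("bert-base-uncased")
    # Export the serving signature; inputs follow the (batch, num_choices, seq_len) spec above.
    tf.saved_model.save(model, "exported_bert_mc", signatures=model.serving)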
@@ -62,11 +62,11 @@ TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings
+# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->ConvBert
 class TFConvBertEmbeddings(tf.keras.layers.Layer):
     """Construct the embeddings from word, position and token_type embeddings."""

-    def __init__(self, config, **kwargs):
+    def __init__(self, config: ConvBertConfig, **kwargs):
         super().__init__(**kwargs)

         self.vocab_size = config.vocab_size
@@ -83,21 +83,21 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
...
@@ -121,8 +121,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
         key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

-        # Take the dot product between "query" and "key" to get the raw
-        # attention scores.
+        # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
@@ -353,7 +352,7 @@ class TFElectraPooler(tf.keras.layers.Layer):
 class TFElectraEmbeddings(tf.keras.layers.Layer):
     """Construct the embeddings from word, position and token_type embeddings."""

-    def __init__(self, config, **kwargs):
+    def __init__(self, config: ElectraConfig, **kwargs):
         super().__init__(**kwargs)

         self.vocab_size = config.vocab_size
@@ -370,21 +369,21 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.embedding_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
...
@@ -491,21 +491,21 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
...
@@ -92,21 +92,21 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
@@ -232,8 +232,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
         key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

-        # Take the dot product between "query" and "key" to get the raw
-        # attention scores.
+        # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
...
@@ -90,21 +90,21 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
             self.weight = self.add_weight(
                 name="weight",
                 shape=[self.vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("token_type_embeddings"):
             self.token_type_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.type_vocab_size, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         with tf.name_scope("position_embeddings"):
             self.position_embeddings = self.add_weight(
                 name="embeddings",
                 shape=[self.max_position_embeddings, self.hidden_size],
-                initializer=get_initializer(initializer_range=self.initializer_range),
+                initializer=get_initializer(self.initializer_range),
             )

         super().build(input_shape)
@@ -197,8 +197,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
         key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

-        # Take the dot product between "query" and "key" to get the raw
-        # attention scores.
+        # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
@@ -1247,7 +1246,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
         "token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
     }])
     # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
-    def serving(self, inputs: Dict[str, tf.Tensor]):
+    def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
         output = self.call(input_ids=inputs)

         return self.serving_output(output)
...
@@ -26,6 +26,7 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 if is_tf_available():
     import tensorflow as tf

+    from transformers import TF_MODEL_FOR_PRETRAINING_MAPPING
     from transformers.models.albert.modeling_tf_albert import (
         TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFAlbertForMaskedLM,
@@ -243,6 +244,16 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
     test_head_masking = False
     test_onnx = False

+    # special case for ForPreTraining model
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+        if return_labels:
+            if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
+                inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+
+        return inputs_dict
+
     def setUp(self):
         self.model_tester = TFAlbertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
@@ -295,10 +306,6 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
             name = model.get_bias()
             assert name is None

-    def test_mixed_precision(self):
-        # TODO JP: Make ALBERT float16 compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
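The test changes do two things: _prepare_for_class now supplies a sentence_order_label whenever labels are requested for the pretraining model (ALBERT's sentence-order-prediction head expects it), and the local test_mixed_precision stub is dropped so ALBERT runs the shared mixed-precision test again. A rough sketch of what such a shared check amounts to (illustrative names, not the exact common-test code):

    import tensorflow as tf

    def check_mixed_precision(model_class, config, inputs_dict):
        # Build and call the model under mixed_float16; with the dtype fixes above
        # this should no longer raise dtype-mismatch errors.
        tf.keras.mixed_precision.set_global_policy("mixed_float16")
        try:
            model = model_class(config)
            outputs = model(inputs_dict)
        finally:
            # Restore the default policy so later tests are unaffected.
            tf.keras.mixed_precision.set_global_policy("float32")
        return outputs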