"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "09013efdf158064c0e253c80dd791aa62b9c48bf"
Unverified Commit fb56bf25 authored by Julien Plu, committed by GitHub

Making TF MobileBert model compliant with AMP (#10259)

* Fix AMP

* Trigger CI

* Rework cast
parent 2fc6284f
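
The change follows one pattern throughout: stop hard-coding tf.float32 and instead cast scalars and masks to the dtype of a tensor produced under the active Keras precision policy, so the graph stays consistent when that policy is mixed_float16 (AMP). Below is a minimal, self-contained sketch of that pattern, with illustrative tensor and function names rather than the model's real code:

import tensorflow as tf

# Sketch of the dtype-following cast pattern applied in this commit.
# Under a "mixed_float16" policy the layer outputs are float16, so any
# scalar or mask combined with them must be cast to match instead of
# being forced to tf.float32.

def scaled_scores_with_mask(query, key, attention_mask=None):
    # (batch, heads, seq_q, seq_k); inherits float16 under mixed precision
    scores = tf.matmul(query, key, transpose_b=True)
    # derive the scaling factor's dtype from the scores, not from float32
    dk = tf.cast(tf.shape(key)[-1], dtype=scores.dtype)
    scores = scores / tf.math.sqrt(dk)
    if attention_mask is not None:
        # cast the additive mask as well, to avoid a float16/float32 mismatch
        attention_mask = tf.cast(attention_mask, dtype=scores.dtype)
        scores = scores + attention_mask
    return scores

q = tf.random.normal((2, 4, 8, 16), dtype=tf.float16)
k = tf.random.normal((2, 4, 8, 16), dtype=tf.float16)
mask = tf.zeros((2, 1, 1, 8))  # float32 on purpose, to exercise the cast
print(scaled_scores_with_mask(q, k, mask).dtype)  # float16

Moving the self.embeddings(...) call earlier in TFMobileBertMainLayer serves the same goal: embedding_output must exist before its dtype can be used to build the extended attention mask constants.
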
@@ -251,11 +251,12 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
+            # Apply the attention mask is (precomputed for all layers in TFMobileBertModel call() function)
+            attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
             attention_scores = attention_scores + attention_mask

         # Normalize the attention scores to probabilities.
@@ -726,6 +727,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         if inputs["token_type_ids"] is None:
             inputs["token_type_ids"] = tf.fill(input_shape, 0)

+        embedding_output = self.embeddings(
+            inputs["input_ids"],
+            inputs["position_ids"],
+            inputs["token_type_ids"],
+            inputs["inputs_embeds"],
+            training=inputs["training"],
+        )
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -738,9 +747,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -752,13 +762,6 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
         else:
             inputs["head_mask"] = [None] * self.num_hidden_layers

-        embedding_output = self.embeddings(
-            inputs["input_ids"],
-            inputs["position_ids"],
-            inputs["token_type_ids"],
-            inputs["inputs_embeds"],
-            training=inputs["training"],
-        )
-
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
...
@@ -310,10 +310,6 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make MobileBert float16 compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         # for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
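
With the casts reworked, the test override that skipped mixed precision (test_mixed_precision with a TODO) is deleted, so TFMobileBert now runs the shared mixed-precision test like the other TF models. A rough manual check along the same lines is sketched below; it assumes TF 2.4+ for the mixed-precision policy API and uses the public google/mobilebert-uncased checkpoint purely as an example:

import tensorflow as tf
from transformers import MobileBertTokenizer, TFMobileBertModel

# Run the forward pass under Keras mixed precision (TF >= 2.4 API).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = TFMobileBertModel.from_pretrained("google/mobilebert-uncased")

inputs = tokenizer("AMP smoke test", return_tensors="tf")
outputs = model(inputs)  # should no longer hit float16/float32 dtype mismatches
print(outputs.last_hidden_state.dtype)  # float16 under the mixed policy

# Restore the default policy so later code is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
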