Unverified commit 86caeb76, authored by Julien Plu and committed by GitHub

Fix XLA and AMP (#10262)

parent 3d72d47f
@@ -169,13 +169,17 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
+        self.pruned_heads = set()
 
+    def build(self, input_shape):
         if self.has_relative_attention_bias:
-            self.relative_attention_bias = tf.keras.layers.Embedding(
-                self.relative_attention_num_buckets,
-                self.n_heads,
-                name="relative_attention_bias",
-            )
-        self.pruned_heads = set()
+            with tf.name_scope("relative_attention_bias"):
+                self.relative_attention_bias = self.add_weight(
+                    name="embeddings",
+                    shape=[self.relative_attention_num_buckets, self.n_heads],
+                )
+
+        return super().build(input_shape)
 
     def prune_heads(self, heads):
         raise NotImplementedError
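The new build() creates the relative-attention bias table as a raw weight instead of an Embedding sublayer, so XLA and mixed precision see a plain gather-compatible variable, while the tf.name_scope keeps the variable name that checkpoints saved with the old Embedding layout expect. A minimal standalone sketch of the same pattern (ToyBiasLayer and its sizes are illustrative, not part of transformers):

```python
import tensorflow as tf


class ToyBiasLayer(tf.keras.layers.Layer):
    """Illustrative layer: the lookup table is created in build(), not __init__."""

    def __init__(self, num_buckets, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads

    def build(self, input_shape):
        # add_weight under an explicit name scope yields a variable named like
        # "<layer>/relative_attention_bias/embeddings", i.e. the same suffix the
        # old Embedding sublayer gave its weight, so existing checkpoints still resolve.
        with tf.name_scope("relative_attention_bias"):
            self.table = self.add_weight(
                name="embeddings", shape=[self.num_buckets, self.n_heads]
            )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # A plain tf.gather replaces the Embedding layer call.
        return tf.gather(self.table, bucket_ids)


layer = ToyBiasLayer(num_buckets=32, n_heads=8)
out = layer(tf.constant([[0, 1], [2, 3]]))  # builds lazily; shape (2, 2, 8)
```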
@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
         # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_buckets += (
+                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
+            )
             relative_position = tf.math.abs(relative_position)
         else:
             relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
         is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
+        relative_position_if_large = max_exact + tf.cast(
+            tf.math.log(relative_position / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
-            tf.int32,
+            dtype=relative_position.dtype,
         )
         relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
         relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
             bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.gather(
+            self.relative_attention_bias, relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
         values = tf.expand_dims(
             tf.transpose(values, [2, 0, 1]), axis=0
         )  # shape (1, num_heads, query_length, key_length)
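For reference, this is the shape flow of the gather-based lookup in compute_bias, with toy sizes picked purely for illustration: a (query_length, key_length) matrix of bucket ids indexes the (num_buckets, n_heads) table, and the result is transposed and expanded so it broadcasts against attention scores of shape (batch_size, n_heads, query_length, key_length).

```python
import tensorflow as tf

# Hypothetical sizes, chosen only to make the shapes visible.
num_buckets, n_heads, q_len, k_len = 32, 8, 5, 7

bias_table = tf.random.normal([num_buckets, n_heads])
buckets = tf.random.uniform([q_len, k_len], maxval=num_buckets, dtype=tf.int32)

values = tf.gather(bias_table, buckets)                            # (q_len, k_len, n_heads)
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)   # (1, n_heads, q_len, k_len)

print(values.shape)  # (1, 8, 5, 7)
```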
@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
                 position_bias = position_bias[:, :, -seq_length:, :]
 
             if mask is not None:
+                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                 position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)
 
         scores += position_bias
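The new cast to mask.dtype (and the removal of the hard-coded tf.float32 above) is what AMP needs: under the Keras mixed-precision policy, activations are float16 while a mask built separately may still be float32, and adding the two without a cast fails. A small standalone sketch of that behaviour, not specific to T5:

```python
import tensorflow as tf

# Enable automatic mixed precision globally (TF >= 2.4; older releases used
# the tf.keras.mixed_precision.experimental namespace).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
x = tf.random.normal([2, 3])           # float32 input
y = dense(x)                           # computed and returned in float16 under the policy
print(y.dtype)                         # <dtype: 'float16'>

# A mask built separately in float32 cannot be added to y directly:
mask = tf.zeros([2, 4], dtype=tf.float32)
y = y + tf.cast(mask, dtype=y.dtype)   # cast to the activation dtype, as the diff does

# Reset the policy so the example has no side effects.
tf.keras.mixed_precision.set_global_policy("float32")
```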
@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
+        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
         num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
         if num_dims_attention_mask == 3:
             extended_attention_mask = inputs["attention_mask"][:, None, :, :]
@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                     tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                     seq_ids[None, :, None],
                 )
-                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
+                causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
                 extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
                 if inputs["past_key_values"][0] is not None:
                     extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
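The causal mask in this hunk is a lower-triangular matrix built by comparing a tiled row of position ids against a column of position ids; the only change is that it is now cast to the attention mask's dtype instead of a fixed tf.float32. A toy version of the construction, with tf.float32 standing in for whatever dtype the mask actually has:

```python
import tensorflow as tf

batch_size, seq_length = 2, 4
seq_ids = tf.range(seq_length)

# 1 where key position <= query position, 0 elsewhere (lower-triangular).
causal_mask = tf.math.less_equal(
    tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)),
    seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)  # in the model: the attention mask's dtype

print(causal_mask[0])
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```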
@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # If a 2D ou 3D attention mask is provided for the cross-attention
             # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
+            inputs["encoder_attention_mask"] = tf.cast(
+                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
+            )
             num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
             if num_dims_encoder_attention_mask == 3:
                 encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
             decoder_start_token_id is not None
         ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
 
-        shifted_input_ids = tf.cast(input_ids, tf.int32)
-        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+        shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
         start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
         shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
         )
 
         # "Verify that `labels` has only positive values and -100"
-        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
 
         # Make sure the assertion op is called by wrapping the result in an identity no-op
        with tf.control_dependencies([assert_gte0]):
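_shift_right now rolls input_ids in their original integer dtype rather than first casting to tf.int32, and checks non-negativity against a plain tf.constant(0). As a rough illustration of the roll-and-replace idiom on toy ids (the decoder_start_token_id of 0 here is hypothetical; T5 normally uses the pad token id):

```python
import tensorflow as tf

decoder_start_token_id = 0  # hypothetical value for the sketch

input_ids = tf.constant([[5, 6, 7, 8],
                         [9, 10, 11, 1]])

shifted = tf.roll(input_ids, 1, axis=-1)                  # last token wraps around to the front
batch_size = input_ids.shape[0]
start_tokens = tf.fill((batch_size, 1), decoder_start_token_id)
shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)   # drop the wrapped token, prepend the start token

print(shifted.numpy())
# [[ 0  5  6  7]
#  [ 0  9 10 11]]
```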
@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
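Removing these overrides means T5 now runs through the shared test_mixed_precision and test_xla_mode checks from TFModelTesterMixin instead of skipping them. As a rough idea of what an XLA smoke check looks like, a sketch assuming the t5-small checkpoint (this is illustrative, not the actual test in the repository):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello.", return_tensors="tf")

# jit_compile=True on recent TF versions; older releases used experimental_compile=True.
@tf.function(jit_compile=True)
def xla_forward(input_ids, attention_mask):
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=input_ids,
    ).last_hidden_state

hidden = xla_forward(inputs["input_ids"], inputs["attention_mask"])
print(hidden.shape)  # (batch_size, sequence_length, d_model)
```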
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_train_pipeline_custom_model(self):
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make T5 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make T5 XLA compliant
-        pass
-
 
 @require_tf
 @require_sentencepiece