Unverified Commit 86caeb76 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix XLA and AMP (#10262)

parent 3d72d47f
......@@ -169,14 +169,18 @@ class TFT5Attention(tf.keras.layers.Layer):
self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
if self.has_relative_attention_bias:
self.relative_attention_bias = tf.keras.layers.Embedding(
self.relative_attention_num_buckets,
self.n_heads,
name="relative_attention_bias",
)
self.pruned_heads = set()
def build(self, input_shape):
if self.has_relative_attention_bias:
with tf.name_scope("relative_attention_bias"):
self.relative_attention_bias = self.add_weight(
name="embeddings",
shape=[self.relative_attention_num_buckets, self.n_heads],
)
return super().build(input_shape)
def prune_heads(self, heads):
raise NotImplementedError
......@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
# n = -relative_position
if bidirectional:
num_buckets //= 2
relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
relative_buckets += (
tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
)
relative_position = tf.math.abs(relative_position)
else:
relative_position = -tf.math.minimum(relative_position, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(relative_position, max_exact)
relative_position_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
relative_position_if_large = max_exact + tf.cast(
tf.math.log(relative_position / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
tf.int32,
dtype=relative_position.dtype,
)
relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
......@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = tf.gather(
self.relative_attention_bias, relative_position_bucket
) # shape (query_length, key_length, num_heads)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]), axis=0
) # shape (1, num_heads, query_length, key_length)
......@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
else:
position_bias = self.compute_bias(real_seq_length, key_length)
......@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
position_bias = position_bias[:, :, -seq_length:, :]
if mask is not None:
position_bias = tf.cast(position_bias, dtype=mask.dtype)
position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length)
scores += position_bias
......@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
if num_dims_attention_mask == 3:
extended_attention_mask = inputs["attention_mask"][:, None, :, :]
......@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
if inputs["past_key_values"][0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
......@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
inputs["encoder_attention_mask"] = tf.cast(
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
......@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
decoder_start_token_id is not None
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
......@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
......
......@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make T5 float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@slow
def test_model_from_pretrained(self):
model = TFT5Model.from_pretrained("t5-small")
......@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
def test_train_pipeline_custom_model(self):
pass
def test_mixed_precision(self):
# TODO JP: Make T5 float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make T5 XLA compliant
pass
@require_tf
@require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment