Unverified Commit 14ed3b97 authored by Julien Plu, committed by GitHub

Fix AMP (#10216)

parent bdf1669e
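The hunks below remove the TF Funnel model's hand-rolled float32 casting so tensors follow whatever dtype policy Keras is running under. As a minimal sketch of that policy mechanism (standard tf.keras APIs, TF >= 2.4; none of this code is from the commit):

import tensorflow as tf

# Once the mixed_float16 policy is active, Keras layers compute in float16
# while keeping float32 master weights; any hard-coded float32 tensor then
# collides with float16 activations downstream.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(4)
y = layer(tf.random.normal([2, 8]))   # input is cast to the compute dtype

print(layer.compute_dtype)    # float16
print(layer.variable_dtype)   # float32
print(y.dtype)                # float16

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default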
@@ -144,7 +144,7 @@ class TFFunnelAttentionStructure:
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
         self.seq_len = seq_len = shape_list(inputs_embeds)[1]
-        position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
+        position_embeds = self.get_position_embeds(seq_len, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
             tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
@@ -161,7 +161,7 @@ class TFFunnelAttentionStructure:
         cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
         return tf.logical_or(cls_mat, token_type_mat)
 
-    def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
+    def get_position_embeds(self, seq_len, training=False):
         """
         Create and cache inputs related to relative position encoding. Those are very different depending on whether we
         are using the factorized or the relative shift attention:
@@ -177,8 +177,8 @@ class TFFunnelAttentionStructure:
         if self.attention_type == "factorized":
             # Notations from the paper, appending A.2.2, final formula.
             # We need to create and return the matrices phi, psi, pi and omega.
-            pos_seq = tf.range(0, seq_len, 1.0, dtype=dtype)
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            pos_seq = tf.range(0, seq_len, 1.0)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
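For reference, the factorized sinusoid above can be reproduced standalone. tf.range with float arguments defaults to float32, so after this change the embedding dtype no longer depends on the inputs; seq_len and d_model below are illustrative, not from a real config:

import tensorflow as tf

seq_len, d_model = 8, 16

pos_seq = tf.range(0, seq_len, 1.0)          # float limits => float32
freq_seq = tf.range(0, d_model // 2, 1.0)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))

sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)   # [seq_len, d_model // 2]
pos_embed = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)

print(pos_embed.shape, pos_embed.dtype)   # (8, 16) float32, policy-independent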
@@ -195,17 +195,17 @@ class TFFunnelAttentionStructure:
         else:
             # Notations from the paper, appending A.2.1, final formula.
             # We need to create and return all the possible vectors R for all blocks and shifts.
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
-            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
+            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)
             zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
             pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
-            pos = tf.range(0, seq_len, dtype=dtype)
+            pos = tf.range(0, seq_len)
             pooled_pos = pos
             position_embeds_list = []
             for block_index in range(0, self.num_blocks):
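Likewise for the relative-shift branch: the position range spans [-2 * seq_len, 2 * seq_len) so every shift used by the first block has an embedding row, and it stays float32 by default. A small sketch with an illustrative seq_len:

import tensorflow as tf

seq_len = 8
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)   # float32 by default
zero_offset = seq_len * tf.constant(2)                  # index of position 0

print(rel_pos_id.shape)   # (32,)
print(rel_pos_id.dtype)   # float32, whatever the global policy is
print(int(zero_offset))   # 16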
@@ -258,7 +258,7 @@ class TFFunnelAttentionStructure:
         else:
             return pos_id[::2]
 
-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
@@ -266,7 +266,7 @@ class TFFunnelAttentionStructure:
        if pooled_pos is None:
             pooled_pos = pos
 
         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
+        num_remove = shift * shape_list(pooled_pos)[0]
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
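With shift now a plain Python int, this arithmetic needs no tf.cast. A toy example with made-up positions for one stride-2 pooling step; pooled_pos.shape[0] stands in for shape_list(pooled_pos)[0] from the model code:

import tensorflow as tf

pos = tf.range(0, 8)     # int32 positions of the unpooled sequence
pooled_pos = pos[::2]    # positions that survive the pooling
stride, shift = 2, 1     # plain Python ints, so no tf.cast is needed

ref_point = pooled_pos[0] - pos[0]
num_remove = shift * pooled_pos.shape[0]
max_dist = ref_point + num_remove * stride
min_dist = pooled_pos[0] - pos[-1]

print(int(max_dist), int(min_dist))   # 8 -7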
@@ -522,17 +522,13 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
         # merge attention scores
         attn_score = content_score + positional_attn + token_type_attn
 
-        # precision safe in case of mixed precision training
-        dtype = attn_score.dtype
-        if dtype != tf.float32:
-            attn_score = tf.cast(attn_score, tf.float32)
         # perform masking
         if attention_mask is not None:
-            attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32))
+            attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
+            attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
+
         # attention probability
         attn_prob = tf.nn.softmax(attn_score, axis=-1)
-        if dtype != tf.float32:
-            attn_prob = tf.cast(attn_prob, dtype)
         attn_prob = self.attention_dropout(attn_prob, training=training)
 
         # attention output, shape batch_size x seq_len x n_head x d_head
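This masking hunk is the core of the fix: instead of round-tripping the scores through float32, the mask is cast to whatever dtype the scores already have. A standalone sketch of the pattern (shapes are illustrative, and INF is kept at 1e4 here so the product fits in float16's range; the model's own constant may be larger):

import tensorflow as tf

INF = 1e4
attn_score = tf.random.normal([2, 1, 4, 4], dtype=tf.float16)
attention_mask = tf.constant([[1, 1, 1, 0],
                              [1, 1, 0, 0]])    # batch_size x seq_len

# Cast the mask to the scores' dtype instead of forcing scores to float32.
attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
attn_prob = tf.nn.softmax(attn_score, axis=-1)   # stays float16 under AMP

print(attn_prob.dtype)   # float16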
@@ -372,10 +372,6 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass
-
 @require_tf
 class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -407,7 +403,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass
-
-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass
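Deleting these two overrides re-enables the shared mixed-precision test for both test classes. Roughly what that test exercises, sketched with an illustrative tiny config (not the test's actual values):

import tensorflow as tf
from transformers import FunnelConfig, TFFunnelModel

tf.keras.mixed_precision.set_global_policy("mixed_float16")

config = FunnelConfig(vocab_size=100, block_sizes=[1, 1], d_model=32,
                      n_head=2, d_head=16)
model = TFFunnelModel(config)
outputs = model(tf.constant([[1, 2, 3, 4]]))   # no dtype-mismatch error now

tf.keras.mixed_precision.set_global_policy("float32")  # reset afterwards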