"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "da3c79b245afcce88f5db79ada10bf5b7c200ab1"
Unverified Commit 14ed3b97 authored by Julien Plu, committed by GitHub

Fix AMP (#10216)

parent bdf1669e
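Some context on the title, not part of the diff: "AMP" here refers to Keras mixed-precision execution, where layers compute in float16 while their variables stay in float32, and the changes below adjust TFFunnel's dtype handling so a forward pass runs under that policy. A minimal sketch of how the mode is switched on, assuming TF 2.4+ (older 2.x releases expose it as `tf.keras.mixed_precision.experimental.set_policy`):

```python
import tensorflow as tf

# Enable mixed precision globally: layers now compute in float16 but keep
# their weights in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(8)
out = layer(tf.ones((2, 4)))
print(layer.compute_dtype)  # float16 (dtype used for computation)
print(layer.dtype)          # float32 (dtype of the variables)
print(out.dtype)            # float16
```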
@@ -144,7 +144,7 @@ class TFFunnelAttentionStructure:
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
         self.seq_len = seq_len = shape_list(inputs_embeds)[1]
-        position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
+        position_embeds = self.get_position_embeds(seq_len, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
             tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
@@ -161,7 +161,7 @@ class TFFunnelAttentionStructure:
         cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
         return tf.logical_or(cls_mat, token_type_mat)
 
-    def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
+    def get_position_embeds(self, seq_len, training=False):
         """
         Create and cache inputs related to relative position encoding. Those are very different depending on whether we
         are using the factorized or the relative shift attention:
@@ -177,8 +177,8 @@ class TFFunnelAttentionStructure:
         if self.attention_type == "factorized":
             # Notations from the paper, appending A.2.2, final formula.
             # We need to create and return the matrices phi, psi, pi and omega.
-            pos_seq = tf.range(0, seq_len, 1.0, dtype=dtype)
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            pos_seq = tf.range(0, seq_len, 1.0)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
@@ -195,17 +195,17 @@ class TFFunnelAttentionStructure:
         else:
             # Notations from the paper, appending A.2.1, final formula.
             # We need to create and return all the possible vectors R for all blocks and shifts.
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
-            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
+            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)
             zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
             pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
-            pos = tf.range(0, seq_len, dtype=dtype)
+            pos = tf.range(0, seq_len)
             pooled_pos = pos
             position_embeds_list = []
             for block_index in range(0, self.num_blocks):
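A side note on the hunks above, not part of the diff: with the explicit dtype argument gone, the position-encoding ranges fall back to `tf.range`'s defaults, so the sinusoid tables are always built in float32 (a float delta gives float32) while integer-argument ranges come out as int32. A quick, purely illustrative check of those defaults and of the outer-product shape of the sinusoid table:

```python
import tensorflow as tf

print(tf.range(0, 8, 1.0).dtype)  # float32: float arguments default to float32
print(tf.range(0, 8).dtype)       # int32: integer arguments default to int32

# The sinusoid table is the outer product of positions and inverse frequencies,
# shown here for 4 positions and d_model // 2 == 3 frequencies.
pos_seq = tf.range(0, 4, 1.0)
inv_freq = 1 / (10000 ** (tf.range(0, 3, 1.0) / 3))
print(tf.einsum("i,d->id", pos_seq, inv_freq).shape)  # (4, 3)
```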
@@ -258,7 +258,7 @@ class TFFunnelAttentionStructure:
         else:
             return pos_id[::2]
 
-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
@@ -266,7 +266,7 @@ class TFFunnelAttentionStructure:
             pooled_pos = pos
 
         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
+        num_remove = shift * shape_list(pooled_pos)[0]
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
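To make the arithmetic in this hunk concrete, here is a small hand evaluation with made-up values (an unpooled sequence of length 8, stride 1, and the new integer shift of 1); plain Python lists stand in for the tensors:

```python
pos = list(range(8))   # positions of the current block: [0, 1, ..., 7]
pooled_pos = pos       # no pooling has happened yet, so pooled_pos defaults to pos
stride, shift = 1, 1

ref_point = pooled_pos[0] - pos[0]           # 0
num_remove = shift * len(pooled_pos)         # 8, a plain int now that shift is an int
max_dist = ref_point + num_remove * stride   # 8
min_dist = pooled_pos[0] - pos[-1]           # -7
print(ref_point, num_remove, max_dist, min_dist)  # 0 8 8 -7
```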
@@ -522,17 +522,13 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
         # merge attention scores
         attn_score = content_score + positional_attn + token_type_attn
 
-        # precision safe in case of mixed precision training
-        dtype = attn_score.dtype
-
-        if dtype != tf.float32:
-            attn_score = tf.cast(attn_score, tf.float32)
-
         # perform masking
         if attention_mask is not None:
-            attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32))
+            attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
+            attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
 
         # attention probability
         attn_prob = tf.nn.softmax(attn_score, axis=-1)
-
-        if dtype != tf.float32:
-            attn_prob = tf.cast(attn_prob, dtype)
-
         attn_prob = self.attention_dropout(attn_prob, training=training)
 
         # attention output, shape batch_size x seq_len x n_head x d_head
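A note on the masking change above, illustrative rather than part of the diff: under a `mixed_float16` policy the attention scores are float16 while the attention mask typically arrives as float32, and TensorFlow does not mix dtypes in arithmetic, so the mask is now cast to the scores' dtype instead of round-tripping the scores through float32. A small sketch (`INF` is a stand-in for the model's masking constant, chosen small enough to stay inside float16 range):

```python
import tensorflow as tf

INF = 1.0e4  # stand-in masking constant for this sketch
attn_score = tf.zeros((2, 4), dtype=tf.float16)                             # as under mixed_float16
attention_mask = tf.constant([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])  # float32 by default

try:
    attn_score - INF * (1 - attention_mask)            # float16 minus float32
except (tf.errors.InvalidArgumentError, TypeError) as err:
    print("dtype mismatch:", type(err).__name__)       # TF rejects the mixed-dtype op

attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)  # the fix from the diff
masked = attn_score - INF * (1 - attention_mask)
print(masked.dtype)  # float16, so the softmax stays in the compute dtype
```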
@@ -372,10 +372,6 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass
-
 
 @require_tf
 class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -407,7 +403,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass
-
-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass