"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c99fe0386be118bceaab1c85cdb8309eb8cb8208"
Unverified commit 52d62e68, authored by Julien Plu, committed by GitHub

Fix TF Funnel (#9300)

* Fix Funnel

* Apply Patrick's comment

* Remove comment

* Fix dummy value

* Apply style
parent 748006c0
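Note: the common thread in the hunks below is swapping static shape attributes (`tensor.shape[1]`, `tensor.shape.ndims`) for the dynamic `shape_list` helper, so the values stay usable when a dimension is unknown at graph-tracing time. The snippet below is a minimal illustration of that difference; the inlined `shape_list` only approximates the helper of the same name in `modeling_tf_utils`, and `double_seq_len` is a made-up toy function, not code from the model.

```python
import tensorflow as tf


def shape_list(x):
    # Rough sketch of the dynamic-shape helper the diff switches to: static
    # dimensions come back as Python ints, unknown (None) dimensions come
    # back as scalar tensors taken from tf.shape at runtime.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def double_seq_len(inputs_embeds):
    # With an unknown sequence length, inputs_embeds.shape[1] is None and
    # `None * 2` fails during tracing; the dynamic lookup below yields a
    # scalar tensor that supports the same arithmetic.
    seq_len = shape_list(inputs_embeds)[1]
    return seq_len * 2


print(double_seq_len(tf.zeros((2, 5, 8))))  # tf.Tensor(10, shape=(), dtype=int32)
```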
@@ -185,7 +185,7 @@ class TFFunnelAttentionStructure:
         # inputs_embeds has shape batch_size x seq_len x d_model
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
-        self.seq_len = seq_len = inputs_embeds.shape[1]
+        self.seq_len = seq_len = shape_list(inputs_embeds)[1]
         position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
@@ -241,7 +241,7 @@ class TFFunnelAttentionStructure:
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
             rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
-            zero_offset = seq_len * 2
+            zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
@@ -257,9 +257,9 @@ class TFFunnelAttentionStructure:
                 # For block_index = 0 we only need the second one and leave the first one as None.
 
                 # First type
-                if block_index == 0:
-                    position_embeds_pooling = None
-                else:
+                position_embeds_pooling = tf.fill([1], value=-1.0)
+
+                if block_index != 0:
                     pooled_pos = self.stride_pool_pos(pos, block_index)
 
                     # construct rel_pos_id
@@ -267,6 +267,7 @@ class TFFunnelAttentionStructure:
                     rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
                     # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                     # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                    rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                     rel_pos = rel_pos + zero_offset
                     position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
@@ -277,6 +278,7 @@ class TFFunnelAttentionStructure:
                 # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
                 # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+                rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
                 rel_pos = rel_pos + zero_offset
                 position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
@@ -298,7 +300,7 @@ class TFFunnelAttentionStructure:
         else:
             return pos_id[::2]
 
-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
@@ -306,11 +308,11 @@ class TFFunnelAttentionStructure:
             pooled_pos = pos
 
         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * pooled_pos.shape[0]
+        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
 
-        return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
+        return tf.range(max_dist, min_dist - 1, -stride)
 
     def stride_pool(self, tensor, axis):
         """
@@ -330,7 +332,7 @@ class TFFunnelAttentionStructure:
             return type(tensor)(self.stride_pool(x, axis) for x in tensor)
 
         # Deal with negative axis
-        axis %= tensor.shape.ndims
+        axis %= len(shape_list(tensor))
 
         axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
         enc_slice = [slice(None)] * axis + [axis_slice]
@@ -352,7 +354,7 @@ class TFFunnelAttentionStructure:
             suffix = tensor[:, :-1] if self.truncate_seq else tensor
             tensor = tf.concat([tensor[:, :1], suffix], axis=1)
 
-        ndim = tensor.shape.ndims
+        ndim = len(shape_list(tensor))
         if ndim == 2:
             tensor = tensor[:, :, None]
@@ -485,10 +487,14 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
                 "bind,jd->bnij", q_r_attention_2, omega
             )
         else:
-            shift = 2 if q_head.shape[1] != context_len else 1
             # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
             # Grab the proper positional encoding, shape max_rel_len x d_model
-            r = position_embeds[self.block_index][shift - 1]
+            if shape_list(q_head)[1] != context_len:
+                shift = 2
+                r = position_embeds[self.block_index][1]
+            else:
+                shift = 1
+                r = position_embeds[self.block_index][0]
             # Shape n_head x d_head
             v = self.r_r_bias * self.scale
             # Shape d_model x n_head x d_head
@@ -517,7 +523,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
         # Shape batch_size x n_head x seq_len x 2
         token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
         # Shape batch_size x n_head x seq_len x context_len
-        new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
+        new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
         token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
         # Shapes batch_size x n_head x seq_len
         diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
@@ -536,7 +542,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
         position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
 
         batch_size, seq_len, _ = shape_list(query)
-        context_len = key.shape[1]
+        context_len = shape_list(key)[1]
         n_head, d_head = self.n_head, self.d_head
 
         # Shape batch_size x seq_len x n_head x d_head
@@ -652,10 +658,13 @@ class TFFunnelEncoder(tf.keras.layers.Layer):
         for block_index, block in enumerate(self.blocks):
             pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
             pooling_flag = pooling_flag and block_index > 0
+            pooled_hidden = tf.zeros(shape_list(hidden))
+
             if pooling_flag:
                 pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
                     hidden, attention_inputs
                 )
+
             for (layer_index, layer) in enumerate(block):
                 for repeat_index in range(self.block_repeats[block_index]):
                     do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
@@ -724,7 +733,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer):
         upsampled_hidden = upsample(
             final_hidden,
             stride=self.stride,
-            target_len=first_block_hidden.shape[1],
+            target_len=shape_list(first_block_hidden)[1],
             separate_cls=self.separate_cls,
             truncate_seq=self.truncate_seq,
         )
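A second pattern in the fix is replacing `None` placeholders with tensor dummy values (`tf.fill([1], value=-1.0)` for the pooled position embeddings, `tf.zeros(shape_list(hidden))` in the encoder loop) that are assigned before a conditional and only overwritten inside it. The toy `maybe_pool` below is an illustration of why, under the assumption that the surrounding code may be traced with a tensor-valued condition; it is not the model's code.

```python
import tensorflow as tf


@tf.function
def maybe_pool(hidden, pooling_flag):
    # The tensor dummy assigned here (the diff uses tf.zeros(shape_list(hidden))
    # for the same purpose) keeps `pooled_hidden` bound to a tensor on every
    # path. Starting from `pooled_hidden = None` is rejected when
    # `pooling_flag` is a tensor, because both branches of the traced
    # conditional must then produce compatible tensor values.
    pooled_hidden = tf.zeros(tf.shape(hidden))
    if pooling_flag:
        pooled_hidden = hidden[:, ::2]  # stand-in for the real pooling
    return pooled_hidden


print(maybe_pool(tf.ones((1, 4, 2)), tf.constant(True)).shape)   # (1, 2, 2)
print(maybe_pool(tf.ones((1, 4, 2)), tf.constant(False)).shape)  # (1, 4, 2)
```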