Unverified Commit 07505358 authored by Matt's avatar Matt Committed by GitHub
Browse files

Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA (#18256)

* Change how `take_along_axis` is computed in DeBERTa to stop confusing XLA

* Greatly simplify take_along_axis() since the code wasn't using most of it
parent d95a32cc
......@@ -522,26 +522,13 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
return tf.broadcast_to(pos_index, shapes)
def take_along_axis(x, indices, gather_axis):
if gather_axis < 0:
gather_axis = tf.rank(x) + gather_axis
if gather_axis != tf.rank(x) - 1:
pre_roll = tf.rank(x) - 1 - gather_axis
permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
x = tf.transpose(x, perm=permutation)
indices = tf.transpose(indices, perm=permutation)
else:
pre_roll = 0
def take_along_axis(x, indices):
# Only a valid port of np.take_along_axis when the gather axis is -1
flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, tf.shape(indices))
if pre_roll != 0:
permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
gathered = tf.transpose(gathered, perm=permutation)
gathered = tf.reshape(gathered, indices.shape)
return gathered
......@@ -775,7 +762,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
tf.squeeze(c2p_pos, 0),
[shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
),
-1,
)
score += c2p_att / scale
......@@ -803,7 +789,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
tf.squeeze(p2c_pos, 0),
[shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
),
-1,
),
[0, 2, 1],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment