"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "47735f5f0f2752500d115d2f6bd57816032599b6"
Unverified Commit 936b3fde authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

Update modeling_tf_deberta.py (#13654)

Fixed expand_dims axis
parent 04976a32
......@@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
if len(shape_list_pos) == 2:
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
elif len(shape_list_pos) == 3:
relative_pos = tf.expand_dims(relative_pos, 0)
relative_pos = tf.expand_dims(relative_pos, 1)
# bxhxqxk
elif len(shape_list_pos) != 4:
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment