"tests/quantization/vscode:/vscode.git/clone" did not exist on "04af4b90d6a7951e2cbad00649af4b8cf2fc90c8"
Unverified Commit 936b3fde authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

Update modeling_tf_deberta.py (#13654)

Fixed expand_dims axis
parent 04976a32
...@@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer): ...@@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
if len(shape_list_pos) == 2: if len(shape_list_pos) == 2:
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
elif len(shape_list_pos) == 3: elif len(shape_list_pos) == 3:
relative_pos = tf.expand_dims(relative_pos, 0) relative_pos = tf.expand_dims(relative_pos, 1)
# bxhxqxk # bxhxqxk
elif len(shape_list_pos) != 4: elif len(shape_list_pos) != 4:
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}") raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment