Unverified Commit 2ee9f9b6 authored by Funtowicz Morgan, committed by GitHub

Fix computation of attention_probs when head_mask is provided. (#9853)



* Fix computation of attention_probs when head_mask is provided.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Apply changes to the template
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent b936582f
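
Why the fix matters: in each of these self-attention layers, attention_probs is produced by a softmax over attention_scores, and nothing after the softmax reads attention_scores again. The old tf.multiply(attention_scores, head_mask) therefore wrote to a tensor that was never used, so head masking silently had no effect on the output. Below is a minimal standalone sketch of the corrected path; the toy shapes and variable names mirror the diff for readability, but this is an illustration, not the library code.

import tensorflow as tf

# Toy dimensions, chosen only for illustration.
batch, num_heads, seq_len, head_dim = 1, 2, 3, 4

query_layer = tf.random.normal((batch, num_heads, seq_len, head_dim))
key_layer = tf.random.normal((batch, num_heads, seq_len, head_dim))
value_layer = tf.random.normal((batch, seq_len, num_heads, head_dim))

# Raw scores of shape (batch, num_heads, seq_q, seq_k), scaled as in BERT.
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = attention_scores / tf.math.sqrt(tf.cast(head_dim, tf.float32))

# The softmax consumes attention_scores; nothing below reads it again,
# which is exactly why masking it (the pre-fix code) was a no-op.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)

# head_mask broadcasts over (batch, num_heads, seq_q, seq_k); 0.0 prunes head 1.
head_mask = tf.reshape(tf.constant([1.0, 0.0]), (1, num_heads, 1, 1))
attention_probs = tf.multiply(attention_probs, head_mask)  # the fixed line

# Same contraction as in the diff: output is (batch, seq_q, num_heads, head_dim).
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)

# The pruned head now contributes exactly zero to the output.
print(tf.reduce_sum(tf.abs(attention_output[:, :, 1, :])).numpy())  # 0.0
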
@@ -370,7 +370,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

@@ -253,7 +253,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

@@ -377,7 +377,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

@@ -317,7 +317,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
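
For contrast, the pre-fix behavior can be reproduced by continuing the sketch above (it reuses attention_scores, head_mask, and value_layer from that snippet): the einsum reads attention_probs, not attention_scores, so masking the scores after the softmax changes nothing.

# Pre-fix behavior, continuing the toy setup above.
masked_scores = tf.multiply(attention_scores, head_mask)   # old line: written, never read
unmasked_probs = tf.nn.softmax(attention_scores, axis=-1)  # already computed before the mask
old_output = tf.einsum("acbe,aecd->abcd", unmasked_probs, value_layer)

# Identical to running with no head_mask at all: same ops on the same inputs,
# so the results match exactly, confirming the mask was ignored.
baseline = tf.einsum("acbe,aecd->abcd", tf.nn.softmax(attention_scores, axis=-1), value_layer)
print(bool(tf.reduce_all(tf.equal(old_output, baseline))))  # True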