Unverified Commit 0eaeae2e authored by Daniel Stancl, committed by GitHub

Fix a condition in test_generate_with_head_masking (#11911)

* Fix a condition in test_generate_with_head_masking

* Fix usage of head_mask in bigbird_pegasus

* Fix head masking for speech2text

* Resolve copy mismatch + drop unwanted print statement

* Fix the condition
parent bebbdd0f
@@ -1174,6 +1174,8 @@ class BigBirdPegasusEncoderAttention(nn.Module):
         from_blocked_mask=None,
         to_blocked_mask=None,
     ):
+        # Expand dims to enable multiplication in the self-attention module
+        head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None
         if self.config.attention_type == "original_full":
             self_outputs = self.self(
@@ -1372,6 +1374,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
         self_attention_outputs = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            head_mask=layer_head_mask,
             output_attentions=output_attentions,
             band_mask=band_mask,
             from_mask=from_mask,
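For context on the BigBirdPegasus change: each encoder layer receives a per-layer head mask of shape (num_attention_heads,), while the self-attention probabilities have shape (batch_size, num_heads, seq_len, seq_len), so the mask has to be reshaped to (1, num_heads, 1, 1) before it can be multiplied in. A minimal standalone sketch of that broadcasting, with made-up shapes (not the model code itself):

import torch

# Illustrative shapes only: batch_size=2, num_heads=4, seq_len=8
attn_probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=-1)   # (batch, num_heads, q_len, k_len)
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])          # one value per head, shape (num_heads,)

# Reshape to (1, num_heads, 1, 1) so it broadcasts over the batch and sequence dims,
# mirroring the `head_mask.reshape(1, -1, 1, 1)` line added above
masked_probs = attn_probs * layer_head_mask.reshape(1, -1, 1, 1)

print(masked_probs[:, 1].abs().sum().item())  # 0.0 -> head 1 is pruned
print(masked_probs[:, 3].abs().sum().item())  # 0.0 -> head 3 is pruned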
@@ -1352,6 +1352,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1366,6 +1368,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
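The Speech2Text change makes prepare_inputs_for_generation forward decoder_head_mask and cross_attn_head_mask at every decoding step, so masks passed to generate() actually reach the decoder. A rough usage sketch under these assumptions: the checkpoint name is only an example, and the dummy input_features stand in for real log-mel features from Speech2TextFeatureExtractor:

import torch
from transformers import Speech2TextForConditionalGeneration

# Example checkpoint (assumption); any Speech2Text checkpoint works the same way
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
cfg = model.config

# One mask value per (layer, head); 1.0 keeps a head, 0.0 prunes it
head_mask = torch.ones(cfg.encoder_layers, cfg.encoder_attention_heads)
decoder_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
decoder_head_mask[:, 0] = 0.0  # prune the first decoder head in every layer
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)

# Dummy log-mel features, shape (batch, frames, input_feat_per_channel * input_channels)
input_features = torch.zeros(1, 584, cfg.input_feat_per_channel * cfg.input_channels)

generated_ids = model.generate(
    input_features,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    max_length=20,
)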
@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
             signature = inspect.signature(model.forward)
             # We want to test only models where encoder/decoder head masking is implemented
-            if set(head_masking.keys()) < set([*signature.parameters.keys()]):
+            if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
                 continue
             for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
                 out = model.generate(
                     input_ids,
+                    attention_mask=attention_mask,
                     num_beams=1,
+                    max_length=max_length,
                     output_attentions=True,
                     return_dict_in_generate=True,
+                    remove_invalid_values=True,
                     **{name: mask},
                 )
                 # We check the state of decoder_attentions and cross_attentions just from the last step
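The test fix hinges on what `<` means for sets: it is the proper-subset operator, so `set(head_masking.keys()) < set([*signature.parameters.keys()])` is True exactly when the model's forward accepts all three head-mask arguments. The old code did `continue` in that case, i.e. it skipped precisely the models whose head masking should be tested; adding `not` skips only the models that do not implement it. A small self-contained illustration with hypothetical forward signatures:

import inspect

def forward_with_masking(input_ids, attention_mask=None, head_mask=None,
                         decoder_head_mask=None, cross_attn_head_mask=None):
    ...

def forward_without_masking(input_ids, attention_mask=None):
    ...

head_masking = {"head_mask": None, "decoder_head_mask": None, "cross_attn_head_mask": None}

for fn in (forward_with_masking, forward_without_masking):
    params = set(inspect.signature(fn).parameters.keys())
    implemented = set(head_masking.keys()) < params  # proper-subset check
    print(fn.__name__, implemented)

# forward_with_masking True     -> must NOT be skipped
# forward_without_masking False -> `if not implemented: continue` skips it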