Unverified commit 0eaeae2e authored by Daniel Stancl, committed by GitHub

Fix a condition in test_generate_with_head_masking (#11911)

* Fix a condition in test_generate_with_head_masking

* Fix usage of head_mask in bigbird_pegasus

* Fix head masking for speech2text

* Resolve copy mismatch + drop unwanted print statement

* Fix the condition
parent bebbdd0f
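Taken together, these changes thread head_mask, decoder_head_mask, and cross_attn_head_mask through generate() for the affected models. A hedged usage sketch follows; the checkpoint name, input text, and all-ones mask shapes are assumptions for illustration, not part of this commit:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

name = "google/bigbird-pegasus-large-arxiv"  # assumed public checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
cfg = model.config

inputs = tokenizer("Head masking silences individual attention heads.", return_tensors="pt")
out = model.generate(
    **inputs,
    max_length=16,
    # one row per layer, one column per head; 1.0 keeps a head, 0.0 masks it
    head_mask=torch.ones(cfg.encoder_layers, cfg.encoder_attention_heads),
    decoder_head_mask=torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads),
    cross_attn_head_mask=torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads),
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))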
@@ -1174,6 +1174,8 @@ class BigBirdPegasusEncoderAttention(nn.Module):
         from_blocked_mask=None,
         to_blocked_mask=None,
     ):
+        # Expand dims to enable multiplication in the self-attention module
+        head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None
         if self.config.attention_type == "original_full":
             self_outputs = self.self(
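Why the (1, -1, 1, 1) reshape: the per-layer mask arrives with shape (num_heads,), while attention probabilities have shape (batch, num_heads, tgt_len, src_len), so the extra dimensions let the elementwise product broadcast over the batch and sequence axes. A minimal sketch with illustrative shapes (not taken from the model):

import torch

batch, num_heads, tgt_len, src_len = 2, 4, 8, 8
attn_probs = torch.softmax(torch.randn(batch, num_heads, tgt_len, src_len), dim=-1)

layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # zero out head 1
masked = attn_probs * layer_head_mask.reshape(1, -1, 1, 1)

assert masked[:, 1].abs().sum() == 0                 # head 1 contributes nothing
assert torch.equal(masked[:, 0], attn_probs[:, 0])   # other heads are untouched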
@@ -1372,6 +1374,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
         self_attention_outputs = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            head_mask=layer_head_mask,
             output_attentions=output_attentions,
             band_mask=band_mask,
             from_mask=from_mask,
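The layer_head_mask passed above is the layer's own row of the full (num_layers, num_heads) head-mask tensor; the encoder slices it per layer before calling each layer. A toy sketch of that slicing (the print stands in for the real layer call, which is not shown here):

import torch

num_layers, num_heads = 3, 4
head_mask = torch.ones(num_layers, num_heads)
head_mask[1, 2] = 0.0                           # prune head 2 of layer 1 only

for idx in range(num_layers):
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    # a real encoder would call: layer(hidden_states, ..., head_mask=layer_head_mask)
    print(idx, layer_head_mask.tolist())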
@@ -1352,6 +1352,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1366,6 +1368,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
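The reason both keys must be returned here: generate() calls prepare_inputs_for_generation at every decoding step and passes the resulting dict straight into the model's forward, so any mask missing from that dict silently never reaches the attention layers. A toy stand-in for that contract (not the Speech2Text code):

class ToySeq2Seq:
    def prepare_inputs_for_generation(self, decoder_input_ids, head_mask=None,
                                      decoder_head_mask=None, cross_attn_head_mask=None,
                                      **kwargs):
        # whatever is returned here is exactly what forward() receives during generation
        return {
            "decoder_input_ids": decoder_input_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
        }

    def forward(self, decoder_input_ids, head_mask=None,
                decoder_head_mask=None, cross_attn_head_mask=None):
        return decoder_head_mask  # non-None only if the key was forwarded above


model = ToySeq2Seq()
step_inputs = model.prepare_inputs_for_generation([0], decoder_head_mask=[[1.0, 0.0]])
assert model.forward(**step_inputs) is not None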
@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
             signature = inspect.signature(model.forward)
             # We want to test only models where encoder/decoder head masking is implemented
-            if set(head_masking.keys()) < set([*signature.parameters.keys()]):
+            if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
                 continue
             for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
                 out = model.generate(
                     input_ids,
                     attention_mask=attention_mask,
                     num_beams=1,
                     max_length=max_length,
                     output_attentions=True,
                     return_dict_in_generate=True,
                     remove_invalid_values=True,
                     **{name: mask},
                 )
                 # We check the state of decoder_attentions and cross_attentions just from the last step
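The one-character fix matters because `<` on Python sets means "is a proper subset of": the old condition made the test continue exactly when the masks were accepted by the model's forward signature, i.e. it skipped every model that does implement head masking. A small illustration (the signature sets are made up; head_masking follows the test):

head_masking = {"head_mask": None, "decoder_head_mask": None, "cross_attn_head_mask": None}

with_masking = {"input_ids", "attention_mask", "head_mask",
                "decoder_head_mask", "cross_attn_head_mask", "labels"}
without_masking = {"input_ids", "attention_mask", "labels"}

# `<` is True precisely when head masking is supported,
# so the test must skip (`continue`) only when it is False.
assert set(head_masking.keys()) < with_masking
assert not set(head_masking.keys()) < without_masking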