Unverified Commit 15c68c67 authored by samuelpullely's avatar samuelpullely Committed by GitHub
Browse files

Enable `decoder_attention_mask` in `generate` function (#20726)

* Enable `decoder_attention_mask` in `generate` function

* Make style corrections

* Run `make repo-consistency`

* Add integration test
parent a9653400
...@@ -666,6 +666,9 @@ class GenerationMixin: ...@@ -666,6 +666,9 @@ class GenerationMixin:
expand_size, dim=0 expand_size, dim=0
) )
model_kwargs["encoder_outputs"] = encoder_outputs model_kwargs["encoder_outputs"] = encoder_outputs
decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
if decoder_attention_mask is not None:
model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
return input_ids, model_kwargs return input_ids, model_kwargs
...@@ -701,13 +704,21 @@ class GenerationMixin: ...@@ -701,13 +704,21 @@ class GenerationMixin:
token_type_ids = model_kwargs["token_type_ids"] token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if not is_encoder_decoder: if not is_encoder_decoder:
# update attention mask
if "attention_mask" in model_kwargs: if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"] attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat( model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
) )
else:
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
dim=-1,
)
return model_kwargs return model_kwargs
......
...@@ -1420,6 +1420,7 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1420,6 +1420,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
decoder_input_ids, decoder_input_ids,
past=None, past=None,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None, cross_attn_head_mask=None,
...@@ -1437,6 +1438,7 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1437,6 +1438,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
......
...@@ -2619,6 +2619,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): ...@@ -2619,6 +2619,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
decoder_input_ids, decoder_input_ids,
past=None, past=None,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None, cross_attn_head_mask=None,
...@@ -2636,6 +2637,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): ...@@ -2636,6 +2637,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
......
...@@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase):
], ],
) )
@slow
def test_decoder_attention_mask(self):
    """Integration check that `generate` honors a user-supplied `decoder_attention_mask`.

    The decoder prompt deliberately contains pad tokens between the decoder
    start token and <s>; masking those positions out should let the model
    continue generating as though the padding were absent.
    """
    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large", forced_bos_token_id=0
    ).to(torch_device)
    tokenizer = self.default_tokenizer

    text = "UN Chief Says There Is No <mask> in Syria"
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(torch_device)

    # Decoder prompt: <decoder_start>, three pad tokens, then <bos>.
    num_pads = 3
    prompt = [model.config.decoder_start_token_id]
    prompt += num_pads * [model.config.pad_token_id]
    prompt += [model.config.bos_token_id]
    decoder_input_ids = torch.tensor([prompt], dtype=torch.long, device=torch_device)

    # 0 at pad positions, 1 everywhere else.
    decoder_attention_mask = torch.where(
        decoder_input_ids == model.config.pad_token_id, 0, 1
    ).to(torch_device)

    generated_ids = model.generate(
        input_ids=input_ids,
        use_cache=False,
        max_new_tokens=20,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
    generated_sentence = tokenizer.batch_decode(generated_ids)[0]
    expected_sentence = "</s><pad><pad><pad><s>UN Chief Says There Is No Plan B for Peace in Syria</s>"
    self.assertEqual(generated_sentence, expected_sentence)
class BartStandaloneDecoderModelTester: class BartStandaloneDecoderModelTester:
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment