Unverified Commit e3ff165a authored by Daniel Stancl, committed by GitHub

Fix cross-attention head mask for Torch encoder-decoder models (#10605)

* Fix cross-attention head mask for Torch BART models

* Fix head masking for cross-attention module for the following
models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart,
Pegasus

* Enable test_headmasking for M2M_100 model

* Fix cross_head_mask for FSMT, LED and T5

* This commit fixes `head_mask` for cross-attention modules
in the following models: FSMT, LED, T5

* It also contains some smaller changes in the docs so that
it is perfectly clear that the shape of `cross_head_mask`
is the same as that of `decoder_head_mask`

* Update template

* Fix template for BartForCausalLM

* Fix cross_head_mask for Speech2Text models

* Fix cross_head_mask in templates

* Fix args order in BartForCausalLM template

* Fix doc in BART templates

* Make more explicit naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Fix doc

* make style quality

* Fix speech2text docstring
parent ca6b80ca
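Before the diff: a minimal usage sketch, not part of this commit, showing the renamed argument from the caller's side. It assumes the post-merge API documented below and the publicly available `facebook/bart-base` checkpoint; `cross_attn_head_mask` has the same shape as `decoder_head_mask`, i.e. `(decoder_layers, decoder_attention_heads)`.

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")

cfg = model.config
# All three masks use 1 = keep the head, 0 = mask the head.
head_mask = torch.ones(cfg.encoder_layers, cfg.encoder_attention_heads)
decoder_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask[0, 0] = 0.0  # e.g. disable the first cross-attention head of the first decoder layer

outputs = model(
    **inputs,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,  # previously reached the decoder as `encoder_head_mask`
)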
@@ -296,7 +296,7 @@ class BartEncoderLayer(nn.Module):
 attention_mask (:obj:`torch.FloatTensor`): attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
 returned tensors for more detail.
@@ -368,7 +368,7 @@ class BartDecoderLayer(nn.Module):
 encoder_hidden_states: Optional[torch.Tensor] = None,
 encoder_attention_mask: Optional[torch.Tensor] = None,
 layer_head_mask: Optional[torch.Tensor] = None,
-encoder_layer_head_mask: Optional[torch.Tensor] = None,
+cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
 past_key_value: Optional[Tuple[torch.Tensor]] = None,
 output_attentions: Optional[bool] = False,
 use_cache: Optional[bool] = True,
@@ -382,9 +382,9 @@ class BartDecoderLayer(nn.Module):
 encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
-encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
-size `(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
+cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+size `(decoder_attention_heads,)`.
 past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -419,7 +419,7 @@ class BartDecoderLayer(nn.Module):
 hidden_states=hidden_states,
 key_value_states=encoder_hidden_states,
 attention_mask=encoder_attention_mask,
-layer_head_mask=encoder_layer_head_mask,
+layer_head_mask=cross_attn_layer_head_mask,
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
@@ -598,18 +598,25 @@ BART_INPUTS_DOCSTRING = r"""
 If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
 modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
 information on the default strategy.
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
 - 0 indicates the head is **masked**.
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
+1]``:
+- 1 indicates the head is **not masked**,
+- 0 indicates the head is **masked**.
 encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
 Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
 :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -710,11 +717,11 @@ class BartEncoder(BartPretrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -875,7 +882,7 @@ class BartDecoder(BartPretrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 use_cache=None,
@@ -912,18 +919,18 @@ class BartDecoder(BartPretrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -993,11 +1000,12 @@ class BartDecoder(BartPretrainedModel):
 all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 next_decoder_cache = () if use_cache else None
-# check if head_mask has a correct number of layers specified if desired
-if head_mask is not None:
-assert head_mask.size()[0] == (
-len(self.layers)
-), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+if attn_mask is not None:
+assert attn_mask.size()[0] == (
+len(self.layers)
+), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
 for idx, decoder_layer in enumerate(self.layers):
 # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
 if output_hidden_states:
@@ -1031,7 +1039,7 @@ class BartDecoder(BartPretrainedModel):
 encoder_hidden_states,
 encoder_attention_mask,
 head_mask[idx] if head_mask is not None else None,
-encoder_head_mask[idx] if encoder_head_mask is not None else None,
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
 None,
 )
 else:
@@ -1042,7 +1050,9 @@ class BartDecoder(BartPretrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
+cross_attn_layer_head_mask=(
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+),
 past_key_value=past_key_value,
 output_attentions=output_attentions,
 use_cache=use_cache,
@@ -1123,6 +1133,7 @@ class BartModel(BartPretrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1172,7 +1183,7 @@ class BartModel(BartPretrainedModel):
 encoder_hidden_states=encoder_outputs[0],
 encoder_attention_mask=attention_mask,
 head_mask=decoder_head_mask,
-encoder_head_mask=head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=decoder_inputs_embeds,
 use_cache=use_cache,
@@ -1248,6 +1259,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1282,6 +1294,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
 decoder_attention_mask=decoder_attention_mask,
 head_mask=head_mask,
 decoder_head_mask=decoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=inputs_embeds,
 decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1386,6 +1399,7 @@ class BartForSequenceClassification(BartPretrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 inputs_embeds=None,
 decoder_inputs_embeds=None,
@@ -1416,6 +1430,7 @@ class BartForSequenceClassification(BartPretrainedModel):
 decoder_attention_mask=decoder_attention_mask,
 head_mask=head_mask,
 decoder_head_mask=decoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 encoder_outputs=encoder_outputs,
 inputs_embeds=inputs_embeds,
 decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1496,6 +1511,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 start_positions=None,
 end_positions=None,
@@ -1527,6 +1543,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
 decoder_attention_mask=decoder_attention_mask,
 head_mask=head_mask,
 decoder_head_mask=decoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 encoder_outputs=encoder_outputs,
 inputs_embeds=inputs_embeds,
 decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1633,7 +1650,7 @@ class BartForCausalLM(BartPretrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 labels=None,
@@ -1666,18 +1683,17 @@ class BartForCausalLM(BartPretrainedModel):
 encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
 in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1734,7 +1750,7 @@ class BartForCausalLM(BartPretrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 head_mask=head_mask,
-encoder_head_mask=encoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=inputs_embeds,
 use_cache=use_cache,
...
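The Blenderbot and Blenderbot-Small hunks below repeat the BART pattern. As a distilled, standalone view of the decoder-side control flow introduced above (illustrative only; `layers`, `head_mask` and `cross_attn_head_mask` here are stand-ins, not the library objects):

import torch

num_layers, num_heads = 12, 16
layers = list(range(num_layers))  # stand-in for self.layers
head_mask = torch.ones(num_layers, num_heads)
cross_attn_head_mask = torch.ones(num_layers, num_heads)

# 1) Both masks are validated against the number of decoder layers.
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
    if attn_mask is not None:
        assert attn_mask.size()[0] == len(layers), (
            f"The `{mask_name}` should be specified for {len(layers)} layers, "
            f"but it is for {attn_mask.size()[0]}."
        )

# 2) Each decoder layer receives one row of each mask; the cross-attention row is now
#    forwarded as `cross_attn_layer_head_mask` (formerly `encoder_layer_head_mask`).
for idx, _decoder_layer in enumerate(layers):
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
    # _decoder_layer(..., layer_head_mask=layer_head_mask,
    #                cross_attn_layer_head_mask=cross_attn_layer_head_mask, ...)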
@@ -298,7 +298,7 @@ class BlenderbotEncoderLayer(nn.Module):
 attention_mask (:obj:`torch.FloatTensor`): attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
 returned tensors for more detail.
@@ -371,7 +371,7 @@ class BlenderbotDecoderLayer(nn.Module):
 encoder_hidden_states: Optional[torch.Tensor] = None,
 encoder_attention_mask: Optional[torch.Tensor] = None,
 layer_head_mask: Optional[torch.Tensor] = None,
-encoder_layer_head_mask: Optional[torch.Tensor] = None,
+cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
 past_key_value: Optional[Tuple[torch.Tensor]] = None,
 output_attentions: Optional[bool] = False,
 use_cache: Optional[bool] = True,
@@ -385,9 +385,9 @@ class BlenderbotDecoderLayer(nn.Module):
 encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
-encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
-size `(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
+cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+size `(decoder_attention_heads,)`.
 past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -423,7 +423,7 @@ class BlenderbotDecoderLayer(nn.Module):
 hidden_states=hidden_states,
 key_value_states=encoder_hidden_states,
 attention_mask=encoder_attention_mask,
-layer_head_mask=layer_head_mask,
+layer_head_mask=cross_attn_layer_head_mask,
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
@@ -554,18 +554,25 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
 decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
 Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
 also be used by default.
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
 - 0 indicates the head is **masked**.
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
+1]``:
+- 1 indicates the head is **not masked**,
+- 0 indicates the head is **masked**.
 encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
 Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
 :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -666,11 +673,11 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -834,7 +841,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 use_cache=None,
@@ -871,18 +878,19 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0,
+1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -951,11 +959,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 next_decoder_cache = () if use_cache else None
-# check if head_mask has a correct number of layers specified if desired
-if head_mask is not None:
-assert head_mask.size()[0] == (
-len(self.layers)
-), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+if attn_mask is not None:
+assert attn_mask.size()[0] == (
+len(self.layers)
+), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
 for idx, decoder_layer in enumerate(self.layers):
 # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
 if output_hidden_states:
@@ -989,7 +998,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 encoder_hidden_states,
 encoder_attention_mask,
 head_mask[idx] if head_mask is not None else None,
-encoder_head_mask[idx] if encoder_head_mask is not None else None,
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
 None,
 )
 else:
@@ -1000,7 +1009,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
+cross_attn_layer_head_mask=(
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+),
 past_key_value=past_key_value,
 output_attentions=output_attentions,
 use_cache=use_cache,
@@ -1090,6 +1101,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1147,7 +1159,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
 encoder_hidden_states=encoder_outputs[0],
 encoder_attention_mask=attention_mask,
 head_mask=decoder_head_mask,
-encoder_head_mask=head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=decoder_inputs_embeds,
 use_cache=use_cache,
@@ -1241,6 +1253,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1275,6 +1288,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
 decoder_attention_mask=decoder_attention_mask,
 head_mask=head_mask,
 decoder_head_mask=decoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=inputs_embeds,
 decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1395,7 +1409,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 labels=None,
@@ -1428,18 +1442,17 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
 encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
 in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1496,7 +1509,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 head_mask=head_mask,
-encoder_head_mask=encoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=inputs_embeds,
 use_cache=use_cache,
...
@@ -296,7 +296,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
 attention_mask (:obj:`torch.FloatTensor`): attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
 returned tensors for more detail.
@@ -369,7 +369,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
 encoder_hidden_states: Optional[torch.Tensor] = None,
 encoder_attention_mask: Optional[torch.Tensor] = None,
 layer_head_mask: Optional[torch.Tensor] = None,
-encoder_layer_head_mask: Optional[torch.Tensor] = None,
+cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
 past_key_value: Optional[Tuple[torch.Tensor]] = None,
 output_attentions: Optional[bool] = False,
 use_cache: Optional[bool] = True,
@@ -383,9 +383,9 @@ class BlenderbotSmallDecoderLayer(nn.Module):
 encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-`(config.encoder_attention_heads,)`.
-encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
-size `(config.encoder_attention_heads,)`.
+`(encoder_attention_heads,)`.
+cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+size `(decoder_attention_heads,)`.
 past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -420,7 +420,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
 hidden_states=hidden_states,
 key_value_states=encoder_hidden_states,
 attention_mask=encoder_attention_mask,
-layer_head_mask=encoder_layer_head_mask,
+layer_head_mask=cross_attn_layer_head_mask,
 past_key_value=cross_attn_past_key_value,
 output_attentions=output_attentions,
 )
@@ -555,18 +555,25 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
 decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
 Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
 also be used by default.
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
 - 0 indicates the head is **masked**.
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
+1]``:
+- 1 indicates the head is **not masked**,
+- 0 indicates the head is **masked**.
 encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
 Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
 :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -667,11 +674,11 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -834,7 +841,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 use_cache=None,
@@ -871,18 +878,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 - 0 for tokens that are **masked**.
 `What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -953,10 +960,12 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 next_decoder_cache = () if use_cache else None
-if head_mask is not None:
-assert head_mask.size()[0] == (
-len(self.layers)
-), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+if attn_mask is not None:
+assert attn_mask.size()[0] == (
+len(self.layers)
+), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
 for idx, decoder_layer in enumerate(self.layers):
 # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
 if output_hidden_states:
@@ -990,7 +999,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states,
 encoder_attention_mask,
 head_mask[idx] if head_mask is not None else None,
-encoder_head_mask[idx] if encoder_head_mask is not None else None,
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
 None,
 )
 else:
@@ -1001,7 +1010,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
+cross_attn_layer_head_mask=(
+cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+),
 past_key_value=past_key_value,
 output_attentions=output_attentions,
 use_cache=use_cache,
@@ -1077,6 +1088,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1134,7 +1146,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states=encoder_outputs[0],
 encoder_attention_mask=attention_mask,
 head_mask=decoder_head_mask,
-encoder_head_mask=head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=decoder_inputs_embeds,
 use_cache=use_cache,
@@ -1216,6 +1228,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
 decoder_attention_mask=None,
 head_mask=None,
 decoder_head_mask=None,
+cross_attn_head_mask=None,
 encoder_outputs=None,
 past_key_values=None,
 inputs_embeds=None,
@@ -1250,6 +1263,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
 decoder_attention_mask=decoder_attention_mask,
 head_mask=head_mask,
 decoder_head_mask=decoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
 past_key_values=past_key_values,
 inputs_embeds=inputs_embeds,
 decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1370,7 +1384,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states=None,
 encoder_attention_mask=None,
 head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
 past_key_values=None,
 inputs_embeds=None,
 labels=None,
@@ -1403,18 +1417,17 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
 encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
 in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
 - 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
 past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -1471,7 +1484,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
 encoder_hidden_states=encoder_hidden_states,
 encoder_attention_mask=encoder_attention_mask,
 head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
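
Putting the renamed argument to use end to end on BlenderbotSmall; this is a hedged sketch rather than a canonical recipe — the `facebook/blenderbot_small-90M` checkpoint is an assumption, while the mask shape `(decoder_layers, decoder_attention_heads)` comes from the docstring above.

```python
import torch
from transformers import BlenderbotSmallForConditionalGeneration, BlenderbotSmallTokenizer

name = "facebook/blenderbot_small-90M"  # assumed checkpoint
model = BlenderbotSmallForConditionalGeneration.from_pretrained(name)
tokenizer = BlenderbotSmallTokenizer.from_pretrained(name)

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="pt")
cfg = model.config

# Mask the first cross-attention head in every decoder layer; self-attention heads stay active.
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask[:, 0] = 0.0

outputs = model(
    **inputs,
    decoder_input_ids=torch.tensor([[cfg.decoder_start_token_id]]),
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
# cross_attentions[layer] has shape (batch, decoder_attention_heads, tgt_len, src_len)
print(outputs.cross_attentions[0].shape)
```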
......
...@@ -248,17 +248,25 @@ FSMT_INPUTS_DOCSTRING = r""" ...@@ -248,17 +248,25 @@ FSMT_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
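
A hedged sketch of the three mask arguments documented in the hunk above, applied to FSMT; the `facebook/wmt19-en-de` checkpoint and the pad-token start for the decoder are assumptions made purely for illustration.

```python
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

name = "facebook/wmt19-en-de"  # assumed checkpoint
model = FSMTForConditionalGeneration.from_pretrained(name)
tokenizer = FSMTTokenizer.from_pretrained(name)

inputs = tokenizer("Machine learning is great.", return_tensors="pt")
cfg = model.config

head_mask = torch.ones(cfg.encoder_layers, cfg.encoder_attention_heads)          # encoder self-attention
decoder_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)  # decoder self-attention
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask[-1, :2] = 0.0  # drop two cross-attention heads in the last decoder layer

outputs = model(
    **inputs,
    decoder_input_ids=torch.tensor([[tokenizer.pad_token_id]]),
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
)
```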
...@@ -573,7 +581,7 @@ class DecoderLayer(nn.Module): ...@@ -573,7 +581,7 @@ class DecoderLayer(nn.Module):
layer_state=None, layer_state=None,
causal_mask=None, causal_mask=None,
layer_head_mask=None, layer_head_mask=None,
encoder_layer_head_mask=None, cross_attn_layer_head_mask=None,
decoder_padding_mask=None, decoder_padding_mask=None,
output_attentions=False, output_attentions=False,
): ):
...@@ -604,7 +612,7 @@ class DecoderLayer(nn.Module): ...@@ -604,7 +612,7 @@ class DecoderLayer(nn.Module):
key=encoder_hidden_states, key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask, key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state layer_state=layer_state, # mutates layer state
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
...@@ -666,7 +674,7 @@ class FSMTDecoder(nn.Module): ...@@ -666,7 +674,7 @@ class FSMTDecoder(nn.Module):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask, decoder_causal_mask,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
...@@ -690,12 +698,11 @@ class FSMTDecoder(nn.Module): ...@@ -690,12 +698,11 @@ class FSMTDecoder(nn.Module):
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
Returns: Returns:
BaseModelOutputWithPast or tuple: BaseModelOutputWithPast or tuple:
...@@ -732,10 +739,11 @@ class FSMTDecoder(nn.Module): ...@@ -732,10 +739,11 @@ class FSMTDecoder(nn.Module):
next_decoder_cache = [] next_decoder_cache = []
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
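
The new check above applies the same shape constraint to both masks: one row per decoder layer. A minimal standalone illustration (sizes are arbitrary example values):

```python
import torch

num_layers, num_heads = 6, 8  # example values


def check_masks(head_mask, cross_attn_head_mask, num_layers):
    # mirrors the zip-based check in the decoder hunk above
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None:
            assert attn_mask.size()[0] == num_layers, (
                f"The `{mask_name}` should be specified for {num_layers} layers, "
                f"but it is for {attn_mask.size()[0]}."
            )


check_masks(torch.ones(num_layers, num_heads), torch.ones(num_layers, num_heads), num_layers)  # passes
try:
    check_masks(None, torch.ones(num_layers - 1, num_heads), num_layers)  # one row short
except AssertionError as err:
    print(err)
```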
...@@ -756,7 +764,7 @@ class FSMTDecoder(nn.Module): ...@@ -756,7 +764,7 @@ class FSMTDecoder(nn.Module):
layer_state=layer_state, layer_state=layer_state,
causal_mask=decoder_causal_mask, causal_mask=decoder_causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -1009,6 +1017,7 @@ class FSMTModel(PretrainedFSMTModel): ...@@ -1009,6 +1017,7 @@ class FSMTModel(PretrainedFSMTModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Tuple] = None, encoder_outputs: Optional[Tuple] = None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
...@@ -1065,7 +1074,7 @@ class FSMTModel(PretrainedFSMTModel): ...@@ -1065,7 +1074,7 @@ class FSMTModel(PretrainedFSMTModel):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask=causal_mask, decoder_causal_mask=causal_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
...@@ -1143,6 +1152,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ...@@ -1143,6 +1152,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
labels=None, labels=None,
...@@ -1173,6 +1183,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ...@@ -1173,6 +1183,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
......
...@@ -901,7 +901,7 @@ class LEDEncoderLayer(nn.Module): ...@@ -901,7 +901,7 @@ class LEDEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
""" """
residual = hidden_states residual = hidden_states
attn_outputs = self.self_attn( attn_outputs = self.self_attn(
...@@ -968,7 +968,7 @@ class LEDDecoderLayer(nn.Module): ...@@ -968,7 +968,7 @@ class LEDDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -982,9 +982,9 @@ class LEDDecoderLayer(nn.Module): ...@@ -982,9 +982,9 @@ class LEDDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`): Whether the base model outputs attentions. output_attentions (:obj:`bool`): Whether the base model outputs attentions.
This requires the attentions tensor to be reshaped in this function. This requires the attentions tensor to be reshaped in this function.
...@@ -1018,7 +1018,7 @@ class LEDDecoderLayer(nn.Module): ...@@ -1018,7 +1018,7 @@ class LEDDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -1199,17 +1199,6 @@ class LEDSeq2SeqModelOutput(ModelOutput): ...@@ -1199,17 +1199,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
last_hidden_state: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None
...@@ -1221,8 +1210,6 @@ class LEDSeq2SeqModelOutput(ModelOutput): ...@@ -1221,8 +1210,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
...@@ -1278,17 +1265,6 @@ class LEDSeq2SeqLMOutput(ModelOutput): ...@@ -1278,17 +1265,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
...@@ -1301,8 +1277,6 @@ class LEDSeq2SeqLMOutput(ModelOutput): ...@@ -1301,8 +1277,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
...@@ -1358,17 +1332,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): ...@@ -1358,17 +1332,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
...@@ -1381,8 +1344,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): ...@@ -1381,8 +1344,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
@dataclass @dataclass
...@@ -1440,17 +1401,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): ...@@ -1440,17 +1401,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Global attentions weights after the attention softmax, used to compute the weighted average in the Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence. in the sequence.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
...@@ -1464,8 +1414,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): ...@@ -1464,8 +1414,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
head_mask: Optional[torch.FloatTensor] = None
decoder_head_mask: Optional[torch.FloatTensor] = None
LED_START_DOCSTRING = r""" LED_START_DOCSTRING = r"""
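
The four hunks above remove `head_mask` and `decoder_head_mask` from the LED output dataclasses: the masks are arguments to `forward`, not values the model returns, whereas the attention weights they act on do stay in the outputs. A hedged way to confirm the resulting field list, assuming the class keeps its standard `cross_attentions` entry and the `transformers.models.led.modeling_led` module path:

```python
import dataclasses

from transformers.models.led.modeling_led import LEDSeq2SeqLMOutput

field_names = [f.name for f in dataclasses.fields(LEDSeq2SeqLMOutput)]
print("head_mask" in field_names, "decoder_head_mask" in field_names)  # False False after this change
print("cross_attentions" in field_names)                               # True
```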
...@@ -1547,17 +1495,24 @@ LED_INPUTS_DOCSTRING = r""" ...@@ -1547,17 +1495,24 @@ LED_INPUTS_DOCSTRING = r"""
- 0 for local attention (a sliding window attention), - 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
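
Putting the documented shapes together for LED; a hedged sketch where the `allenai/led-base-16384` checkpoint is an assumption. Note that `cross_attn_head_mask` follows the decoder geometry even though the LED encoder uses Longformer-style local/global attention.

```python
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

name = "allenai/led-base-16384"  # assumed checkpoint
model = LEDForConditionalGeneration.from_pretrained(name)
tokenizer = LEDTokenizer.from_pretrained(name)

inputs = tokenizer("Summarize this very long report ...", return_tensors="pt")

global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # give the first token global attention

cfg = model.config
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask[0] = 0.0  # silence every cross-attention head in the first decoder layer

outputs = model(
    **inputs,
    global_attention_mask=global_attention_mask,
    decoder_input_ids=torch.tensor([[cfg.decoder_start_token_id]]),
    cross_attn_head_mask=cross_attn_head_mask,
)
```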
...@@ -1730,7 +1685,7 @@ class LEDEncoder(LEDPreTrainedModel): ...@@ -1730,7 +1685,7 @@ class LEDEncoder(LEDPreTrainedModel):
- 0 for local attention (a sliding window attention), - 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
...@@ -1914,7 +1869,7 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -1914,7 +1869,7 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -1961,18 +1916,17 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -1961,18 +1916,17 @@ class LEDDecoder(LEDPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -2052,11 +2006,12 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2052,11 +2006,12 @@ class LEDDecoder(LEDPreTrainedModel):
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -2090,7 +2045,7 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2090,7 +2045,7 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -2100,7 +2055,9 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2100,7 +2055,9 @@ class LEDDecoder(LEDPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -2180,6 +2137,7 @@ class LEDModel(LEDPreTrainedModel): ...@@ -2180,6 +2137,7 @@ class LEDModel(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
past_key_values=None, past_key_values=None,
...@@ -2224,7 +2182,7 @@ class LEDModel(LEDPreTrainedModel): ...@@ -2224,7 +2182,7 @@ class LEDModel(LEDPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -2306,6 +2264,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2306,6 +2264,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
past_key_values=None, past_key_values=None,
...@@ -2358,6 +2317,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2358,6 +2317,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -2463,6 +2423,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel): ...@@ -2463,6 +2423,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -2495,6 +2456,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel): ...@@ -2495,6 +2456,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -2571,6 +2533,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): ...@@ -2571,6 +2533,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
global_attention_mask=None, global_attention_mask=None,
start_positions=None, start_positions=None,
...@@ -2604,6 +2567,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): ...@@ -2604,6 +2567,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
......
...@@ -367,7 +367,7 @@ class M2M100EncoderLayer(nn.Module): ...@@ -367,7 +367,7 @@ class M2M100EncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
...@@ -440,7 +440,7 @@ class M2M100DecoderLayer(nn.Module): ...@@ -440,7 +440,7 @@ class M2M100DecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -454,9 +454,9 @@ class M2M100DecoderLayer(nn.Module): ...@@ -454,9 +454,9 @@ class M2M100DecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -492,7 +492,7 @@ class M2M100DecoderLayer(nn.Module): ...@@ -492,7 +492,7 @@ class M2M100DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
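
The hunk above is the substantive M2M100 fix: the cross-attention block used to receive the self-attention `layer_head_mask` and now gets its own `cross_attn_layer_head_mask`, so the two masks can finally differ. A standalone sketch of that distinction (the head count is an example value):

```python
import torch

decoder_attention_heads = 16  # example value

# Keep every self-attention head, but drop half of the cross-attention heads in this layer.
layer_head_mask = torch.ones(decoder_attention_heads)
cross_attn_layer_head_mask = torch.ones(decoder_attention_heads)
cross_attn_layer_head_mask[decoder_attention_heads // 2 :] = 0.0

# Before the fix the layer reused `layer_head_mask` for its cross-attention, so a configuration
# like this one was impossible to express.
assert not torch.equal(layer_head_mask, cross_attn_layer_head_mask)
```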
...@@ -603,6 +603,24 @@ M2M_100_INPUTS_DOCSTRING = r""" ...@@ -603,6 +603,24 @@ M2M_100_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
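
With the documented shapes above, a hedged end-to-end sketch for M2M100; the `facebook/m2m100_418M` checkpoint and the `src_lang`/`tgt_lang` pair are assumptions made for illustration.

```python
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

name = "facebook/m2m100_418M"  # assumed checkpoint
model = M2M100ForConditionalGeneration.from_pretrained(name)
tokenizer = M2M100Tokenizer.from_pretrained(name, src_lang="en", tgt_lang="de")

inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt")
cfg = model.config

head_mask = torch.ones(cfg.encoder_layers, cfg.encoder_attention_heads)
decoder_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask = torch.ones(cfg.decoder_layers, cfg.decoder_attention_heads)
cross_attn_head_mask[:, :4] = 0.0  # mask the first four cross-attention heads in every decoder layer

outputs = model(
    **inputs,
    decoder_input_ids=torch.tensor([[cfg.decoder_start_token_id]]),
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
)
```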
...@@ -704,6 +722,12 @@ class M2M100Encoder(M2M100PreTrainedModel): ...@@ -704,6 +722,12 @@ class M2M100Encoder(M2M100PreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -841,7 +865,7 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -841,7 +865,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -878,6 +902,19 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -878,6 +902,19 @@ class M2M100Decoder(M2M100PreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
...@@ -955,11 +992,12 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -955,11 +992,12 @@ class M2M100Decoder(M2M100PreTrainedModel):
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -993,7 +1031,7 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -993,7 +1031,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -1004,7 +1042,9 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -1004,7 +1042,9 @@ class M2M100Decoder(M2M100PreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -1085,6 +1125,7 @@ class M2M100Model(M2M100PreTrainedModel): ...@@ -1085,6 +1125,7 @@ class M2M100Model(M2M100PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1126,7 +1167,7 @@ class M2M100Model(M2M100PreTrainedModel): ...@@ -1126,7 +1167,7 @@ class M2M100Model(M2M100PreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -1201,6 +1242,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1201,6 +1242,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1249,6 +1291,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1249,6 +1291,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1281,7 +1324,14 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1281,7 +1324,14 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
): ):
# cut decoder_input_ids if past is used # cut decoder_input_ids if past is used
if past is not None: if past is not None:
...@@ -1293,6 +1343,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1293,6 +1343,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
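
`prepare_inputs_for_generation` now carries `head_mask` through, so the encoder-side mask can ride along with a `generate()` call as well; a hedged sketch, assuming the same `facebook/m2m100_418M` checkpoint and that `generate` forwards extra model kwargs into `prepare_inputs_for_generation` as it does in this version of the library (only `head_mask` is routed by this particular hunk — the decoder-side masks are not).

```python
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

name = "facebook/m2m100_418M"  # assumed checkpoint
model = M2M100ForConditionalGeneration.from_pretrained(name)
tokenizer = M2M100Tokenizer.from_pretrained(name, src_lang="en", tgt_lang="de")

inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt")

head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # drop a single encoder head and see whether the translation shifts

generated = model.generate(
    **inputs,
    head_mask=head_mask,
    forced_bos_token_id=tokenizer.get_lang_id("de"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```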
......
...@@ -313,7 +313,7 @@ class MarianEncoderLayer(nn.Module): ...@@ -313,7 +313,7 @@ class MarianEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
...@@ -386,7 +386,7 @@ class MarianDecoderLayer(nn.Module): ...@@ -386,7 +386,7 @@ class MarianDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -400,9 +400,9 @@ class MarianDecoderLayer(nn.Module): ...@@ -400,9 +400,9 @@ class MarianDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(decoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -437,7 +437,7 @@ class MarianDecoderLayer(nn.Module): ...@@ -437,7 +437,7 @@ class MarianDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -567,18 +567,25 @@ MARIAN_INPUTS_DOCSTRING = r""" ...@@ -567,18 +567,25 @@ MARIAN_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
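Editor's note: a usage sketch for the arguments documented in the hunk above. The checkpoint name is an assumption (any Marian checkpoint should behave the same); the layer and head counts are read from the loaded config rather than hard-coded.

import torch
from transformers import MarianMTModel, MarianTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"  # assumed checkpoint
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
decoder_input_ids = torch.full((1, 1), model.config.decoder_start_token_id, dtype=torch.long)

# shape (decoder_layers, decoder_attention_heads): keep every cross-attention head
# except head 0 of the first decoder layer
cross_attn_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
cross_attn_head_mask[0, 0] = 0.0

outputs = model(
    **inputs,
    decoder_input_ids=decoder_input_ids,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
print(outputs.cross_attentions[0][:, 0].abs().max())  # tensor(0.) for the masked head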
...@@ -678,11 +685,11 @@ class MarianEncoder(MarianPreTrainedModel): ...@@ -678,11 +685,11 @@ class MarianEncoder(MarianPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -842,7 +849,7 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -842,7 +849,7 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -879,18 +886,18 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -879,18 +886,18 @@ class MarianDecoder(MarianPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -959,11 +966,12 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -959,11 +966,12 @@ class MarianDecoder(MarianPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
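Editor's note: the validation loop added in the hunk above now covers both head_mask and cross_attn_head_mask and names the offending argument. A small, self-contained reproduction on a tiny, randomly-initialized Marian (all sizes below are illustrative assumptions):

import torch
from transformers import MarianConfig, MarianModel

config = MarianConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    pad_token_id=0, eos_token_id=1, decoder_start_token_id=0,
)
model = MarianModel(config)

input_ids = torch.tensor([[5, 6, 7, 1]])
decoder_input_ids = torch.tensor([[0, 5]])

bad_mask = torch.ones(1, config.decoder_attention_heads)  # only 1 layer specified, model has 2
try:
    model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, decoder_head_mask=bad_mask)
except AssertionError as exc:
    print(exc)  # names the mask and the expected number of layers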
...@@ -997,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -997,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -1008,7 +1016,9 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -1008,7 +1016,9 @@ class MarianDecoder(MarianPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -1084,6 +1094,7 @@ class MarianModel(MarianPreTrainedModel): ...@@ -1084,6 +1094,7 @@ class MarianModel(MarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1142,7 +1153,7 @@ class MarianModel(MarianPreTrainedModel): ...@@ -1142,7 +1153,7 @@ class MarianModel(MarianPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -1229,6 +1240,7 @@ class MarianMTModel(MarianPreTrainedModel): ...@@ -1229,6 +1240,7 @@ class MarianMTModel(MarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1264,6 +1276,7 @@ class MarianMTModel(MarianPreTrainedModel): ...@@ -1264,6 +1276,7 @@ class MarianMTModel(MarianPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1391,7 +1404,7 @@ class MarianForCausalLM(MarianPreTrainedModel): ...@@ -1391,7 +1404,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
...@@ -1424,18 +1437,17 @@ class MarianForCausalLM(MarianPreTrainedModel): ...@@ -1424,18 +1437,17 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -1492,7 +1504,7 @@ class MarianForCausalLM(MarianPreTrainedModel): ...@@ -1492,7 +1504,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
......
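Editor's note: the causal-LM path above accepts the renamed cross_attn_head_mask as well; it only has an effect when encoder_hidden_states is supplied. A toy sketch with assumed sizes and a randomly-initialized MarianForCausalLM:

import torch
from transformers import MarianConfig, MarianForCausalLM

config = MarianConfig(
    vocab_size=99, d_model=16, decoder_layers=2, decoder_attention_heads=4, decoder_ffn_dim=32,
    pad_token_id=0, eos_token_id=1, decoder_start_token_id=0,
)
model = MarianForCausalLM(config)

decoder_input_ids = torch.tensor([[0, 5, 6]])
encoder_hidden_states = torch.randn(1, 7, config.d_model)  # e.g. features from some external encoder

cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[1] = 0.0  # drop all cross-attention heads in the second decoder layer

out = model(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
print(out.cross_attentions[1].abs().max())  # tensor(0.) for the fully masked layer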
...@@ -303,7 +303,7 @@ class MBartEncoderLayer(nn.Module): ...@@ -303,7 +303,7 @@ class MBartEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
...@@ -375,7 +375,7 @@ class MBartDecoderLayer(nn.Module): ...@@ -375,7 +375,7 @@ class MBartDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -389,9 +389,9 @@ class MBartDecoderLayer(nn.Module): ...@@ -389,9 +389,9 @@ class MBartDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -427,7 +427,7 @@ class MBartDecoderLayer(nn.Module): ...@@ -427,7 +427,7 @@ class MBartDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -595,18 +595,25 @@ MBART_INPUTS_DOCSTRING = r""" ...@@ -595,18 +595,25 @@ MBART_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
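Editor's note: cross_attn_head_mask, like decoder_head_mask, is a dense float tensor rather than a list of head indices. A hypothetical helper (the name and signature are not part of the library) for building one from a {layer: heads} spec:

import torch

def build_head_mask(heads_to_silence: dict, num_layers: int, num_heads: int) -> torch.Tensor:
    # Turn {layer_index: [head_indices]} into a float mask of shape (num_layers, num_heads)
    # with 0.0 for silenced heads and 1.0 everywhere else.
    mask = torch.ones(num_layers, num_heads)
    for layer, heads in heads_to_silence.items():
        mask[layer, heads] = 0.0
    return mask

# e.g. silence cross-attention heads 2 and 3 in the first decoder layer:
# cross_attn_head_mask = build_head_mask({0: [2, 3]}, config.decoder_layers, config.decoder_attention_heads)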
...@@ -708,11 +715,11 @@ class MBartEncoder(MBartPreTrainedModel): ...@@ -708,11 +715,11 @@ class MBartEncoder(MBartPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -877,7 +884,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -877,7 +884,7 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -914,18 +921,18 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -914,18 +921,18 @@ class MBartDecoder(MBartPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -995,11 +1002,12 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -995,11 +1002,12 @@ class MBartDecoder(MBartPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -1033,7 +1041,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1033,7 +1041,7 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -1044,7 +1052,9 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1044,7 +1052,9 @@ class MBartDecoder(MBartPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -1127,6 +1137,7 @@ class MBartModel(MBartPreTrainedModel): ...@@ -1127,6 +1137,7 @@ class MBartModel(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1173,7 +1184,7 @@ class MBartModel(MBartPreTrainedModel): ...@@ -1173,7 +1184,7 @@ class MBartModel(MBartPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -1254,6 +1265,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): ...@@ -1254,6 +1265,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1287,6 +1299,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): ...@@ -1287,6 +1299,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1384,6 +1397,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel): ...@@ -1384,6 +1397,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
inputs_embeds=None, inputs_embeds=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
...@@ -1414,6 +1428,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel): ...@@ -1414,6 +1428,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1495,6 +1510,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): ...@@ -1495,6 +1510,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
...@@ -1526,6 +1542,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): ...@@ -1526,6 +1542,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1634,7 +1651,7 @@ class MBartForCausalLM(MBartPreTrainedModel): ...@@ -1634,7 +1651,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
...@@ -1667,18 +1684,17 @@ class MBartForCausalLM(MBartPreTrainedModel): ...@@ -1667,18 +1684,17 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -1735,7 +1751,7 @@ class MBartForCausalLM(MBartPreTrainedModel): ...@@ -1735,7 +1751,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
......
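Editor's note: in the spirit of the head-masking tests this PR enables, a quick self-contained check on a tiny, randomly-initialized MBart (all sizes are illustrative assumptions) that a zeroed cross-attention head really produces zero cross-attention weights:

import torch
from transformers import MBartConfig, MBartModel

config = MBartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    pad_token_id=0, bos_token_id=1, eos_token_id=2,
)
model = MBartModel(config).eval()

input_ids = torch.tensor([[5, 6, 7, 2]])
decoder_input_ids = torch.tensor([[2, 5, 6]])

cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[0, 1] = 0.0  # silence head 1 of decoder layer 0

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )

assert torch.all(out.cross_attentions[0][:, 1] == 0.0)       # masked head is silenced
assert not torch.all(out.cross_attentions[0][:, 0] == 0.0)   # other heads still attend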
...@@ -313,7 +313,7 @@ class PegasusEncoderLayer(nn.Module): ...@@ -313,7 +313,7 @@ class PegasusEncoderLayer(nn.Module):
attention_mask (:obj:`torch.FloatTensor`): attention mask of size attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail. returned tensors for more detail.
...@@ -386,7 +386,7 @@ class PegasusDecoderLayer(nn.Module): ...@@ -386,7 +386,7 @@ class PegasusDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -400,9 +400,9 @@ class PegasusDecoderLayer(nn.Module): ...@@ -400,9 +400,9 @@ class PegasusDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -438,7 +438,7 @@ class PegasusDecoderLayer(nn.Module): ...@@ -438,7 +438,7 @@ class PegasusDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -566,18 +566,25 @@ PEGASUS_INPUTS_DOCSTRING = r""" ...@@ -566,18 +566,25 @@ PEGASUS_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
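Editor's note: the three masks documented in the hunk above differ only in which config attributes set their shapes. A short sketch using the default PegasusConfig (a real checkpoint's config would normally be read from the loaded model instead):

import torch
from transformers import PegasusConfig

config = PegasusConfig()  # defaults; assumed stand-in for a checkpoint's config
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)             # encoder self-attention
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)     # decoder self-attention
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)  # decoder cross-attention
cross_attn_head_mask[-1] = 0.0  # e.g. silence every cross-attention head in the last decoder layer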
...@@ -679,11 +686,11 @@ class PegasusEncoder(PegasusPreTrainedModel): ...@@ -679,11 +686,11 @@ class PegasusEncoder(PegasusPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -848,7 +855,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -848,7 +855,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -885,18 +892,18 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -885,18 +892,18 @@ class PegasusDecoder(PegasusPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
on hidden heads. Mask values selected in ``[0, 1]``: cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -965,11 +972,12 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -965,11 +972,12 @@ class PegasusDecoder(PegasusPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -1003,7 +1011,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -1003,7 +1011,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -1014,7 +1022,9 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -1014,7 +1022,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -1092,6 +1102,7 @@ class PegasusModel(PegasusPreTrainedModel): ...@@ -1092,6 +1102,7 @@ class PegasusModel(PegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1150,7 +1161,7 @@ class PegasusModel(PegasusPreTrainedModel): ...@@ -1150,7 +1161,7 @@ class PegasusModel(PegasusPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -1232,6 +1243,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1232,6 +1243,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1267,6 +1279,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1267,6 +1279,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -1390,7 +1403,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel): ...@@ -1390,7 +1403,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
...@@ -1423,18 +1436,17 @@ class PegasusForCausalLM(PegasusPreTrainedModel): ...@@ -1423,18 +1436,17 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -1491,7 +1503,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel): ...@@ -1491,7 +1503,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
......
...@@ -451,7 +451,7 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -451,7 +451,7 @@ class Speech2TextDecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -465,9 +465,9 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -465,9 +465,9 @@ class Speech2TextDecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
:obj:`(config.encoder_attention_heads,)`. :obj:`(encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size :obj:`(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -503,7 +503,7 @@ class Speech2TextDecoderLayer(nn.Module): ...@@ -503,7 +503,7 @@ class Speech2TextDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -623,19 +623,29 @@ SPEECH_TO_TEXT_INPUTS_DOCSTRING = r""" ...@@ -623,19 +623,29 @@ SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`). :obj:`past_key_values`).
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
If you want to change padding behavior, you should read
:func:`modeling_speech_to_text._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the
paper <https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
...@@ -728,11 +738,11 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): ...@@ -728,11 +738,11 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -884,7 +894,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -884,7 +894,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -921,18 +931,18 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -921,18 +931,18 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``: on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -1001,12 +1011,12 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -1001,12 +1011,12 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -1039,7 +1049,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -1039,7 +1049,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -1050,7 +1060,9 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -1050,7 +1060,9 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -1127,6 +1139,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel): ...@@ -1127,6 +1139,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
...@@ -1166,7 +1179,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel): ...@@ -1166,7 +1179,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -1240,6 +1253,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): ...@@ -1240,6 +1253,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
decoder_inputs_embeds=None, decoder_inputs_embeds=None,
...@@ -1296,6 +1310,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): ...@@ -1296,6 +1310,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
......
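Taken together, the docstrings in this file fix the expected shapes: `head_mask` is `(encoder_layers, encoder_attention_heads)`, while `decoder_head_mask` and `cross_attn_head_mask` are `(decoder_layers, decoder_attention_heads)`. A hedged sketch of building such masks from a default `Speech2TextConfig` and re-running the per-layer-count check the decoder now performs (the helper `validate_mask` is illustrative, not library code):

import torch
from transformers import Speech2TextConfig

config = Speech2TextConfig()  # default hyper-parameters, used only to illustrate the shapes

head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[0, 0] = 0.0  # drop one cross-attention head in the first decoder layer

def validate_mask(mask, name, num_layers):
    # mirrors the per-layer-count check added above: one row of head entries per decoder layer
    assert mask.size(0) == num_layers, (
        f"The `{name}` should be specified for {num_layers} layers, but it is for {mask.size(0)}."
    )

for mask, name in zip([decoder_head_mask, cross_attn_head_mask], ["decoder_head_mask", "cross_attn_head_mask"]):
    validate_mask(mask, name, config.decoder_layers)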
...@@ -607,7 +607,7 @@ class T5Block(nn.Module): ...@@ -607,7 +607,7 @@ class T5Block(nn.Module):
encoder_attention_mask=None, encoder_attention_mask=None,
encoder_decoder_position_bias=None, encoder_decoder_position_bias=None,
layer_head_mask=None, layer_head_mask=None,
encoder_layer_head_mask=None, cross_attn_layer_head_mask=None,
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
...@@ -661,7 +661,7 @@ class T5Block(nn.Module): ...@@ -661,7 +661,7 @@ class T5Block(nn.Module):
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
query_length=query_length, query_length=query_length,
use_cache=use_cache, use_cache=use_cache,
...@@ -846,7 +846,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -846,7 +846,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=None, encoder_attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -913,7 +913,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -913,7 +913,7 @@ class T5Stack(T5PreTrainedModel):
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
...@@ -925,7 +925,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -925,7 +925,7 @@ class T5Stack(T5PreTrainedModel):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
encoder_layer_head_mask = encoder_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
...@@ -942,8 +942,8 @@ class T5Stack(T5PreTrainedModel): ...@@ -942,8 +942,8 @@ class T5Stack(T5PreTrainedModel):
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if layer_head_mask is not None: if layer_head_mask is not None:
layer_head_mask = layer_head_mask.to(hidden_states.device) layer_head_mask = layer_head_mask.to(hidden_states.device)
if encoder_layer_head_mask is not None: if cross_attn_layer_head_mask is not None:
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
...@@ -955,7 +955,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -955,7 +955,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
encoder_layer_head_mask=encoder_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
...@@ -1082,12 +1082,19 @@ T5_INPUTS_DOCSTRING = r""" ...@@ -1082,12 +1082,19 @@ T5_INPUTS_DOCSTRING = r"""
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. in the decoder Mask values selected in ``[0, Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
1]``: 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
...@@ -1263,6 +1270,7 @@ class T5Model(T5PreTrainedModel): ...@@ -1263,6 +1270,7 @@ class T5Model(T5PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1338,7 +1346,7 @@ class T5Model(T5PreTrainedModel): ...@@ -1338,7 +1346,7 @@ class T5Model(T5PreTrainedModel):
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -1451,6 +1459,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1451,6 +1459,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1551,7 +1560,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1551,7 +1560,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
......
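Since the T5 models now accept `cross_attn_head_mask` and route it to the decoder stack, a short end-to-end sketch may help; the configuration below is tiny and randomly initialised (illustrative hyper-parameters, not a released checkpoint), and the asserts rely on masked attention weights being returned as zeros, which is what the updated common test checks:

import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(vocab_size=128, d_model=32, d_kv=8, d_ff=64, num_layers=2, num_heads=4)
model = T5ForConditionalGeneration(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
decoder_input_ids = torch.randint(0, config.vocab_size, (1, 5))

decoder_head_mask = torch.ones(config.num_layers, config.num_heads)
cross_attn_head_mask = torch.ones(config.num_layers, config.num_heads)
cross_attn_head_mask[0, 0] = 0.0  # disable head 0 of the first decoder block's cross-attention only

outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
# one cross-attention tensor per decoder block, shape (batch, num_heads, tgt_len, src_len)
assert outputs.cross_attentions[0][:, 0].abs().sum() == 0   # masked cross-attention head
assert outputs.decoder_attentions[0][:, 0].abs().sum() > 0  # matching self-attention head untouched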
...@@ -1041,10 +1041,11 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m ...@@ -1041,10 +1041,11 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None,
inputs_embeds=None, inputs_embeds=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
labels=None, labels=None,
use_cache=None, use_cache=None,
...@@ -1876,7 +1877,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): ...@@ -1876,7 +1877,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None, cross_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
...@@ -1890,9 +1891,9 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): ...@@ -1890,9 +1891,9 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`. `(encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of cross_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(config.encoder_attention_heads,)`. size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
...@@ -1927,7 +1928,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): ...@@ -1927,7 +1928,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -2070,18 +2071,24 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): ...@@ -2070,18 +2071,24 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
information on the default strategy. information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
...@@ -2211,10 +2218,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model ...@@ -2211,10 +2218,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -2377,7 +2385,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2377,7 +2385,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -2414,18 +2422,17 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2414,18 +2422,17 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -2493,12 +2500,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2493,12 +2500,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
if head_mask is not None: for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
assert head_mask.size()[0] == ( if attn_mask is not None:
len(self.layers) assert attn_mask.size()[0] == (
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
...@@ -2529,7 +2536,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2529,7 +2536,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
) )
else: else:
...@@ -2540,7 +2547,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2540,7 +2547,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
...@@ -2621,6 +2628,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna ...@@ -2621,6 +2628,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -2662,7 +2670,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna ...@@ -2662,7 +2670,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
encoder_head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -2743,6 +2751,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte ...@@ -2743,6 +2751,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -2791,6 +2800,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte ...@@ -2791,6 +2800,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -3124,7 +3134,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m ...@@ -3124,7 +3134,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
...@@ -3157,18 +3167,17 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m ...@@ -3157,18 +3167,17 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -3225,7 +3234,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m ...@@ -3225,7 +3234,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
......
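The template mirrors the BART implementation, so the decoder-only path documented above (a `*ForCausalLM` model attending over externally supplied `encoder_hidden_states`) can be exercised with `BartForCausalLM`. A hedged sketch with an illustrative tiny configuration and random weights:

import torch
from transformers import BartConfig, BartForCausalLM

config = BartConfig(vocab_size=64, d_model=16, decoder_layers=2, decoder_attention_heads=4,
                    decoder_ffn_dim=32, max_position_embeddings=32)  # illustrative values
model = BartForCausalLM(config).eval()

input_ids = torch.randint(4, config.vocab_size, (1, 5))
encoder_hidden_states = torch.randn(1, 7, config.d_model)  # stand-in for real encoder outputs

head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[-1, -1] = 0.0  # mask the last cross-attention head of the last layer

out = model(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states,
            head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask,
            output_attentions=True)
# one cross-attention tensor per decoder layer, shape (batch, num_heads, tgt_len, src_len)
assert out.cross_attentions[-1][:, -1].abs().sum() == 0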
...@@ -55,6 +55,7 @@ def prepare_bart_inputs_dict( ...@@ -55,6 +55,7 @@ def prepare_bart_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -64,6 +65,8 @@ def prepare_bart_inputs_dict( ...@@ -64,6 +65,8 @@ def prepare_bart_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
...@@ -71,6 +74,7 @@ def prepare_bart_inputs_dict( ...@@ -71,6 +74,7 @@ def prepare_bart_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
......
...@@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict( ...@@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -54,6 +55,8 @@ def prepare_blenderbot_inputs_dict( ...@@ -54,6 +55,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
...@@ -61,6 +64,7 @@ def prepare_blenderbot_inputs_dict( ...@@ -61,6 +64,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
......
...@@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict( ...@@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -59,6 +60,8 @@ def prepare_blenderbot_small_inputs_dict( ...@@ -59,6 +60,8 @@ def prepare_blenderbot_small_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
...@@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict( ...@@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
......
...@@ -225,8 +225,8 @@ class ModelTesterMixin: ...@@ -225,8 +225,8 @@ class ModelTesterMixin:
"decoder_attention_mask", "decoder_attention_mask",
] ]
expected_arg_names.extend( expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"] ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"] else ["encoder_outputs"]
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
...@@ -492,6 +492,8 @@ class ModelTesterMixin: ...@@ -492,6 +492,8 @@ class ModelTesterMixin:
arg_names = [*signature.parameters.keys()] arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask inputs["decoder_head_mask"] = head_mask
if "cross_attn_head_mask" in arg_names:
inputs["cross_attn_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True) outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation # Test that we can get a gradient back for importance score computation
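The hunk above feeds the same mask tensor under whichever of the three head-mask argument names a model's forward signature exposes. A standalone, hedged sketch of that signature-driven injection (`add_head_masks` and `fake_forward` are illustrative helpers, not part of the test suite):

import inspect
import torch

def add_head_masks(forward_fn, inputs, head_mask):
    # pass the same mask under every head-mask argument the signature accepts
    arg_names = set(inspect.signature(forward_fn).parameters)
    for name in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"):
        if name in arg_names:
            inputs[name] = head_mask
    return inputs

def fake_forward(input_ids, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None):
    return [k for k, v in locals().items() if v is not None]

inputs = add_head_masks(fake_forward, {"input_ids": torch.tensor([[1, 2, 3]])}, torch.ones(2, 4))
assert sorted(inputs) == ["cross_attn_head_mask", "decoder_head_mask", "head_mask", "input_ids"]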
...@@ -523,6 +525,7 @@ class ModelTesterMixin: ...@@ -523,6 +525,7 @@ class ModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions) check_attentions_validity(outputs.decoder_attentions)
check_attentions_validity(outputs.cross_attentions)
else: else:
check_attentions_validity(outputs.attentions) check_attentions_validity(outputs.attentions)
...@@ -1093,7 +1096,7 @@ class ModelTesterMixin: ...@@ -1093,7 +1096,7 @@ class ModelTesterMixin:
# some params shouldn't be scattered by nn.DataParallel # some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present. # so just remove them if they are present.
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"] blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
for k in blacklist_non_batched_params: for k in blacklist_non_batched_params:
inputs_dict.pop(k, None) inputs_dict.pop(k, None)
......
...@@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict( ...@@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict(
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -120,6 +121,8 @@ def prepare_fsmt_inputs_dict( ...@@ -120,6 +121,8 @@ def prepare_fsmt_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
......
...@@ -52,6 +52,7 @@ def prepare_led_inputs_dict( ...@@ -52,6 +52,7 @@ def prepare_led_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -61,6 +62,8 @@ def prepare_led_inputs_dict( ...@@ -61,6 +62,8 @@ def prepare_led_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
...@@ -68,6 +71,7 @@ def prepare_led_inputs_dict( ...@@ -68,6 +71,7 @@ def prepare_led_inputs_dict(
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
......
...@@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict( ...@@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
...@@ -142,9 +154,10 @@ class M2M100ModelTester: ...@@ -142,9 +154,10 @@ class M2M100ModelTester:
model = M2M100Model(config=config).get_decoder().to(torch_device).eval() model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"] attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
...@@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase ...@@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False
test_missing_keys = False test_missing_keys = False
def setUp(self): def setUp(self):
......
...@@ -60,6 +60,7 @@ def prepare_marian_inputs_dict( ...@@ -60,6 +60,7 @@ def prepare_marian_inputs_dict(
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id) attention_mask = input_ids.ne(config.pad_token_id)
...@@ -69,6 +70,8 @@ def prepare_marian_inputs_dict( ...@@ -69,6 +70,8 @@ def prepare_marian_inputs_dict(
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
...@@ -76,6 +79,7 @@ def prepare_marian_inputs_dict( ...@@ -76,6 +79,7 @@ def prepare_marian_inputs_dict(
"decoder_attention_mask": attention_mask, "decoder_attention_mask": attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
} }
......