Unverified Commit f45cb66b authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add head_mask, decoder_head_mask, cross_head_mask to ProphetNet (#9964)

* Add head_mask & decoder_head_mask + some corrections

* Fix head masking for N-grams

* Enable test_headmasking for encoder and decoder

* Fix one typo in modeling_prophetnet.py

* Enable test_headmasking for ProphetNetStandaloneDecoderModelTest
and ProphetNetStandaloneEncoderModelTest in test_modeling_prophetnet.py

* make style

* Fix cross_head_mask

* Fix attention head mask naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Still need to merge #10605 to master to pass the tests
parent 52166f67
...@@ -104,6 +104,24 @@ PROPHETNET_INPUTS_DOCSTRING = r""" ...@@ -104,6 +104,24 @@ PROPHETNET_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default. also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
...@@ -146,6 +164,12 @@ PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r""" ...@@ -146,6 +164,12 @@ PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`): output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail. tensors for more detail.
...@@ -633,6 +657,7 @@ class ProphetNetAttention(nn.Module): ...@@ -633,6 +657,7 @@ class ProphetNetAttention(nn.Module):
hidden_states, hidden_states,
key_value_states: Optional[Tensor] = None, key_value_states: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
past_key_value: Optional[Tuple[Tensor]] = None, past_key_value: Optional[Tuple[Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
...@@ -708,6 +733,19 @@ class ProphetNetAttention(nn.Module): ...@@ -708,6 +733,19 @@ class ProphetNetAttention(nn.Module):
attn_weights_reshaped = None attn_weights_reshaped = None
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = F.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
attn_probs = F.dropout( attn_probs = F.dropout(
attn_weights, attn_weights,
p=self.attention_dropout, p=self.attention_dropout,
...@@ -797,6 +835,7 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -797,6 +835,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
hidden_states, hidden_states,
past_key_value: Optional[Tuple[Tensor]] = None, past_key_value: Optional[Tuple[Tensor]] = None,
attention_mask=None, attention_mask=None,
layer_head_mask=None,
extended_predict_attention_mask=None, extended_predict_attention_mask=None,
main_relative_position_buckets=None, main_relative_position_buckets=None,
predict_relative_position_buckets=None, predict_relative_position_buckets=None,
...@@ -876,6 +915,15 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -876,6 +915,15 @@ class ProphetNetNgramSelfAttention(nn.Module):
onnx_trace=self.onnx_trace, onnx_trace=self.onnx_trace,
).type_as(main_attn_weights) ).type_as(main_attn_weights)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output # project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states) main_attn_output = torch.bmm(main_attn_probs, main_value_states)
...@@ -929,6 +977,18 @@ class ProphetNetNgramSelfAttention(nn.Module): ...@@ -929,6 +977,18 @@ class ProphetNetNgramSelfAttention(nn.Module):
dim=-1, dim=-1,
onnx_trace=self.onnx_trace, onnx_trace=self.onnx_trace,
).type_as(predict_attn_weights) ).type_as(predict_attn_weights)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training) predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
# project to attention output # project to attention output
# [ngram, B*head, T, c] # [ngram, B*head, T, c]
...@@ -1063,11 +1123,18 @@ class ProphetNetEncoderLayer(nn.Module): ...@@ -1063,11 +1123,18 @@ class ProphetNetEncoderLayer(nn.Module):
self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim) self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size) self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
def forward(self, hidden_states, attention_mask, output_attentions: bool = False): def forward(
self,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions: bool = False,
):
# 1st residual block # 1st residual block
attention_output, attn_weights, _ = self.self_attn( attention_output, attn_weights, _ = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
...@@ -1110,6 +1177,8 @@ class ProphetNetDecoderLayer(nn.Module): ...@@ -1110,6 +1177,8 @@ class ProphetNetDecoderLayer(nn.Module):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attn_mask=None, encoder_attn_mask=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
extended_predict_attention_mask=None, extended_predict_attention_mask=None,
main_relative_position_buckets=None, main_relative_position_buckets=None,
predict_relative_position_buckets=None, predict_relative_position_buckets=None,
...@@ -1125,6 +1194,7 @@ class ProphetNetDecoderLayer(nn.Module): ...@@ -1125,6 +1194,7 @@ class ProphetNetDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=self_attn_past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
extended_predict_attention_mask=extended_predict_attention_mask, extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets, main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets,
...@@ -1141,6 +1211,7 @@ class ProphetNetDecoderLayer(nn.Module): ...@@ -1141,6 +1211,7 @@ class ProphetNetDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attn_mask, attention_mask=encoder_attn_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
...@@ -1202,6 +1273,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): ...@@ -1202,6 +1273,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
self, self,
input_ids=None, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
...@@ -1254,7 +1326,12 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): ...@@ -1254,7 +1326,12 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
encoder_hidden_states = () if output_hidden_states else None encoder_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
for encoder_layer in self.layers: # check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states: if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,) encoder_hidden_states = encoder_hidden_states + (hidden_states,)
...@@ -1270,10 +1347,14 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): ...@@ -1270,10 +1347,14 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
create_custom_forward(encoder_layer), create_custom_forward(encoder_layer),
hidden_states, hidden_states,
extended_attention_mask, extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
...@@ -1338,6 +1419,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1338,6 +1419,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=None, use_cache=None,
...@@ -1352,6 +1435,12 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1352,6 +1435,12 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
...@@ -1460,6 +1549,12 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1460,6 +1549,12 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
all_cross_attns = () if output_attentions and self.config.add_cross_attention else None all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
present_key_values = () if use_cache else None present_key_values = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states: if output_hidden_states:
# grad cannot be kept because tensor is sliced # grad cannot be kept because tensor is sliced
...@@ -1491,6 +1586,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1491,6 +1586,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
extended_attention_mask, extended_attention_mask,
encoder_hidden_states, encoder_hidden_states,
extended_encoder_attention_mask, extended_encoder_attention_mask,
(head_mask[idx] if head_mask is not None else None),
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask, extended_predict_attention_mask,
main_relative_position_buckets, main_relative_position_buckets,
predict_relative_position_buckets, predict_relative_position_buckets,
...@@ -1503,6 +1600,10 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1503,6 +1600,10 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask, encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
extended_predict_attention_mask=extended_predict_attention_mask, extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets, main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets,
...@@ -1678,6 +1779,9 @@ class ProphetNetModel(ProphetNetPreTrainedModel): ...@@ -1678,6 +1779,9 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Tuple] = None, encoder_outputs: Optional[Tuple] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1716,6 +1820,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel): ...@@ -1716,6 +1820,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -1728,6 +1833,8 @@ class ProphetNetModel(ProphetNetPreTrainedModel): ...@@ -1728,6 +1833,8 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
attention_mask=decoder_attention_mask, attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds, inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
...@@ -1785,6 +1892,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1785,6 +1892,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1828,6 +1938,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1828,6 +1938,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1902,7 +2015,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1902,7 +2015,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
return loss return loss
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
): ):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
...@@ -1915,6 +2035,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1915,6 +2035,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, "use_cache": use_cache,
} }
...@@ -1985,6 +2106,8 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1985,6 +2106,8 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
labels=None, labels=None,
...@@ -2000,6 +2123,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2000,6 +2123,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
...@@ -2060,6 +2189,8 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2060,6 +2189,8 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
...@@ -2123,7 +2254,15 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2123,7 +2254,15 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
return loss return loss
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
**kwargs,
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None: if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape) attention_mask = input_ids.new_ones(input_ids.shape)
...@@ -2134,6 +2273,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2134,6 +2273,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
return { return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past, "past_key_values": past,
"use_cache": use_cache, "use_cache": use_cache,
} }
......
...@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test ...@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = True is_encoder_decoder = True
def setUp(self): def setUp(self):
...@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix ...@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False is_encoder_decoder = False
def setUp(self): def setUp(self):
...@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = False test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False is_encoder_decoder = False
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment