Unverified commit 71bdc076, authored by Daniel Stancl and committed by GitHub

Add head_mask and decoder_head_mask to PyTorch LED (#9856)

* Add {decoder_,}head_mask to LED

* Fix create_custom_forward signature in encoder

* Add head_mask to longformer

* Add head_mask to longformer to fix dependencies of LED on Longformer.

* Not working yet

* Add one missing input in longformer_modeling.py

* make fix-copies
parent d6217fb3
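
For orientation before the diff, here is a minimal usage sketch of the two new arguments. It is illustrative only and not part of the commit: the checkpoint name is an assumption, and the masks follow the `(num_layers, num_heads)` convention documented in the diff below.

```python
# Illustrative sketch only (not part of the commit): mask one encoder head at inference time.
# Assumes the "allenai/led-base-16384" checkpoint; any LED checkpoint should work the same way.
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A very long document to summarize ...", return_tensors="pt")

# Shape (num_layers, num_heads); 1 keeps a head, 0 masks it, as documented in the diff.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # silence head 0 of encoder layer 0
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)

decoder_input_ids = torch.full((1, 1), model.config.decoder_start_token_id, dtype=torch.long)
outputs = model(
    **inputs,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
logits = outputs.logits  # (batch_size, 1, vocab_size)
```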
@@ -164,6 +164,7 @@ class LEDEncoderSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -251,6 +252,12 @@ class LEDEncoderSelfAttention(nn.Module):
         attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
         attn_probs = attn_probs.type_as(attn_scores)
@@ -288,6 +295,7 @@ class LEDEncoderSelfAttention(nn.Module):
             global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
+                layer_head_mask=layer_head_mask,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -595,6 +603,7 @@ class LEDEncoderSelfAttention(nn.Module):
         self,
         hidden_states,
         max_num_global_attn_indices,
+        layer_head_mask,
         is_local_index_global_attn_nonzero,
         is_index_global_attn_nonzero,
         is_local_index_no_global_attn_nonzero,
@@ -656,6 +665,18 @@ class LEDEncoderSelfAttention(nn.Module):
             global_attn_scores, dim=-1, dtype=torch.float32
         )  # use fp32 for numerical stability
+        # apply layer head masking
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
+                batch_size, self.num_heads, max_num_global_attn_indices, seq_len
+            )
+            global_attn_probs_float = global_attn_probs_float.view(
+                batch_size * self.num_heads, max_num_global_attn_indices, seq_len
+            )
         global_attn_probs = F.dropout(
             global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
         )
@@ -686,6 +707,7 @@ class LEDEncoderAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
         is_index_masked: Optional[torch.Tensor] = None,
         is_index_global_attn: Optional[torch.Tensor] = None,
         is_global_attn: Optional[bool] = None,
@@ -696,6 +718,7 @@ class LEDEncoderAttention(nn.Module):
         self_outputs = self.longformer_self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -744,6 +767,7 @@ class LEDDecoderAttention(nn.Module):
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -810,6 +834,12 @@ class LEDDecoderAttention(nn.Module):
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         if output_attentions:
             # this operation is a bit awkward, but it's required to
@@ -859,6 +889,7 @@ class LEDEncoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -869,11 +900,14 @@ class LEDEncoderLayer(nn.Module):
             hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (:obj:`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(config.encoder_attention_heads,)`.
         """
         residual = hidden_states
         attn_outputs = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -931,6 +965,8 @@ class LEDDecoderLayer(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        encoder_layer_head_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
@@ -943,6 +979,10 @@ class LEDDecoderLayer(nn.Module):
             encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(config.encoder_attention_heads,)`.
+            encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
+                size `(config.encoder_attention_heads,)`.
             past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (:obj:`bool`): Whether the base model outputs attentions.
                 This requires the attentions tensor to be reshaped in this function.
@@ -957,6 +997,7 @@ class LEDDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
         )
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -975,6 +1016,7 @@ class LEDDecoderLayer(nn.Module):
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 output_attentions=output_attentions,
             )
@@ -1155,6 +1197,17 @@ class LEDSeq2SeqModelOutput(ModelOutput):
             Global attentions weights after the attention softmax, used to compute the weighted average in the
             self-attention heads. Those are the attention weights from every token with global attention to every token
             in the sequence.
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
     """
     last_hidden_state: torch.FloatTensor = None
@@ -1166,6 +1219,8 @@ class LEDSeq2SeqModelOutput(ModelOutput):
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
     encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    head_mask: Optional[torch.FloatTensor] = None
+    decoder_head_mask: Optional[torch.FloatTensor] = None
 @dataclass
@@ -1221,6 +1276,17 @@ class LEDSeq2SeqLMOutput(ModelOutput):
             Global attentions weights after the attention softmax, used to compute the weighted average in the
             self-attention heads. Those are the attention weights from every token with global attention to every token
             in the sequence.
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
     """
     loss: Optional[torch.FloatTensor] = None
@@ -1233,6 +1299,8 @@ class LEDSeq2SeqLMOutput(ModelOutput):
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
     encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    head_mask: Optional[torch.FloatTensor] = None
+    decoder_head_mask: Optional[torch.FloatTensor] = None
 @dataclass
@@ -1288,6 +1356,17 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
             Global attentions weights after the attention softmax, used to compute the weighted average in the
             self-attention heads. Those are the attention weights from every token with global attention to every token
             in the sequence.
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
     """
     loss: Optional[torch.FloatTensor] = None
@@ -1300,6 +1379,8 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
     encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    head_mask: Optional[torch.FloatTensor] = None
+    decoder_head_mask: Optional[torch.FloatTensor] = None
 @dataclass
@@ -1357,6 +1438,17 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
             Global attentions weights after the attention softmax, used to compute the weighted average in the
             self-attention heads. Those are the attention weights from every token with global attention to every token
             in the sequence.
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
     """
     loss: Optional[torch.FloatTensor] = None
@@ -1370,6 +1462,8 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
     encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    head_mask: Optional[torch.FloatTensor] = None
+    decoder_head_mask: Optional[torch.FloatTensor] = None
 LED_START_DOCSTRING = r"""
@@ -1442,6 +1536,17 @@ LED_INPUTS_DOCSTRING = r"""
             - 0 for local attention (a sliding window attention),
             - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
         encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
             :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -1582,6 +1687,7 @@ class LEDEncoder(LEDPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         inputs_embeds=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1615,6 +1721,11 @@ class LEDEncoder(LEDPreTrainedModel):
                 - 0 for local attention (a sliding window attention),
                 - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+            head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
             inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
                 Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
                 representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -1686,7 +1797,12 @@ class LEDEncoder(LEDPreTrainedModel):
         all_attentions = () if output_attentions else None
         all_global_attentions = () if (output_attentions and is_global_attn) else None
-        for encoder_layer in self.layers:
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+        for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -1707,6 +1823,7 @@ class LEDEncoder(LEDPreTrainedModel):
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
                     is_index_masked,
                     is_index_global_attn,
                 )
@@ -1714,6 +1831,7 @@ class LEDEncoder(LEDPreTrainedModel):
                 layer_outputs = encoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                     is_index_masked=is_index_masked,
                     is_index_global_attn=is_index_global_attn,
                     is_global_attn=is_global_attn,
@@ -1787,6 +1905,8 @@ class LEDDecoder(LEDPreTrainedModel):
         global_attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
         use_cache=None,
@@ -1833,6 +1953,19 @@ class LEDDecoder(LEDPreTrainedModel):
                 - 0 for tokens that are **masked**.
                 `What are attention masks? <../glossary.html#attention-mask>`__
+            head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+            encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
+                on hidden heads. Mask values selected in ``[0, 1]``:
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
             past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                 decoding.
@@ -1910,6 +2043,12 @@ class LEDDecoder(LEDPreTrainedModel):
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
@@ -1942,6 +2081,8 @@ class LEDDecoder(LEDPreTrainedModel):
                     combined_attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    encoder_head_mask[idx] if encoder_head_mask is not None else None,
                     None,
                 )
             else:
@@ -1950,6 +2091,8 @@ class LEDDecoder(LEDPreTrainedModel):
                     attention_mask=combined_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
@@ -2027,6 +2170,8 @@ class LEDModel(LEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         global_attention_mask=None,
         past_key_values=None,
@@ -2049,6 +2194,7 @@ class LEDModel(LEDPreTrainedModel):
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 global_attention_mask=global_attention_mask,
+                head_mask=head_mask,
                 inputs_embeds=inputs_embeds,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
@@ -2069,6 +2215,8 @@ class LEDModel(LEDPreTrainedModel):
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
@@ -2148,6 +2296,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         global_attention_mask=None,
         past_key_values=None,
@@ -2198,6 +2348,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
             decoder_attention_mask=decoder_attention_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2231,7 +2383,14 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         )
     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
         # cut decoder_input_ids if past is used
         if past is not None:
@@ -2243,6 +2402,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
@@ -2290,6 +2450,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         global_attention_mask=None,
         inputs_embeds=None,
@@ -2320,6 +2482,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2394,6 +2558,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         global_attention_mask=None,
         start_positions=None,
@@ -2425,6 +2591,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
......
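
As a reading aid for the `layer_head_mask` blocks added above, the following standalone sketch (not from the diff) shows the broadcast performed by `layer_head_mask.view(1, 1, -1, 1) * attn_probs`; the tensor shapes are illustrative and assume LED's local-attention layout of `(batch_size, seq_len, num_heads, attention_window + 1)`.

```python
# Standalone broadcasting sketch with illustrative shapes (not code from the diff).
import torch

batch_size, seq_len, num_heads, window = 2, 16, 4, 9
attn_probs = torch.softmax(torch.randn(batch_size, seq_len, num_heads, window), dim=-1)

# One scalar per attention head of this layer: 1.0 keeps the head, 0.0 silences it.
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])

# view(1, 1, -1, 1) lines the mask up with the head dimension, as in the added code.
masked_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs

assert torch.all(masked_probs[:, :, 1, :] == 0)  # the masked head contributes nothing
print(masked_probs.shape)  # torch.Size([2, 16, 4, 9])
```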
@@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module):
         attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
         attn_probs = attn_probs.type_as(attn_scores)
@@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module):
             global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
+                layer_head_mask=layer_head_mask,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module):
         self,
         hidden_states,
         max_num_global_attn_indices,
+        layer_head_mask,
         is_local_index_global_attn_nonzero,
         is_index_global_attn_nonzero,
         is_local_index_no_global_attn_nonzero,
@@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module):
             global_attn_scores, dim=-1, dtype=torch.float32
         )  # use fp32 for numerical stability
+        # apply layer head masking
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
+                batch_size, self.num_heads, max_num_global_attn_indices, seq_len
+            )
+            global_attn_probs_float = global_attn_probs_float.view(
+                batch_size * self.num_heads, max_num_global_attn_indices, seq_len
+            )
         global_attn_probs = F.dropout(
             global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
         )
@@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module):
         self_outputs = self.self(
             hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module):
         self_attn_outputs = self.attention(
             hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        head_mask=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module):
         all_attentions = () if output_attentions else None  # All local attentions.
         all_global_attentions = () if (output_attentions and is_global_attn) else None
-        for i, layer_module in enumerate(self.layer):
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layer)
+            ), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
+        for idx, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module):
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
                     is_index_masked,
                     is_index_global_attn,
                 )
@@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module):
                 layer_outputs = layer_module(
                     hidden_states,
                     attention_mask=attention_mask,
+                    layer_head_mask=head_mask[idx] if head_mask is not None else None,
                     is_index_masked=is_index_masked,
                     is_index_global_attn=is_index_global_attn,
                     is_global_attn=is_global_attn,
@@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
             - 0 for local attention (a sliding window attention),
             - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
         token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
             1]``:
@@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel):
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
         token_type_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         labels=None,
         position_ids=None,
         inputs_embeds=None,
@@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            head_mask=head_mask,
             inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
......
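
The Longformer changes track the LED ones (the commit message notes LED's dependency on Longformer and the `make fix-copies` run). A minimal sketch of the new standalone-Longformer argument, with an assumed checkpoint name:

```python
# Illustrative sketch only: pass a per-layer, per-head mask to a standalone Longformer.
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# One row per layer, one column per head; zeros disable heads.
head_mask = torch.ones(model.config.num_hidden_layers, model.config.num_attention_heads)
head_mask[2, :2] = 0.0  # mask the first two heads of layer 2

outputs = model(**inputs, head_mask=head_mask)
print(outputs.last_hidden_state.shape)
```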
@@ -473,7 +473,6 @@ class ModelTesterMixin:
             arg_names = [*signature.parameters.keys()]
             if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
                 inputs["decoder_head_mask"] = head_mask
-
             outputs = model(**inputs, return_dict=True)
             # Test that we can get a gradient back for importance score computation
......
@@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
     decoder_input_ids,
     attention_mask=None,
     decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
         "decoder_attention_mask": decoder_attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
     }
@@ -160,9 +168,10 @@ class LEDModelTester:
         model = LEDModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]
         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
         output, past_key_values = outputs.to_tuple()
@@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
     test_pruning = False
-    test_head_masking = False
     test_missing_keys = False
     def setUp(self):
......
@@ -273,7 +273,6 @@ class LongformerModelTester:
 @require_torch
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False  # pruning is not supported
-    test_headmasking = False  # head masking is not supported
     test_torchscript = False
     all_model_classes = (
......