Unverified commit 71bdc076, authored by Daniel Stancl, committed by GitHub

Add head_mask and decoder_head_mask to PyTorch LED (#9856)

* Add {decoder_,}head_mask to LED

* Fix create_custom_forward signature in encoder

* Add head_mask to longformer

* Add head_mask to longformer to fix dependencies of LED on Longformer

* Not working yet

* Add one missing input in longformer_modeling.py

* make fix-copies
parent d6217fb3
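For context, a rough usage sketch of what this change enables, not part of the commit itself. The checkpoint name and the masked head indices below are purely illustrative, following the Transformers API of this period:

```python
# Hedged sketch: passing the new head_mask / decoder_head_mask arguments to an LED model.
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A long document to summarize ...", return_tensors="pt")

# One row per layer, one column per head: 1.0 keeps a head, 0.0 silences it.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # e.g. disable head 0 of encoder layer 0 (illustrative choice)
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"][:, :1],  # minimal decoder input, just for the example
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
```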
@@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module):
         attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability

+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
+
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
         attn_probs = attn_probs.type_as(attn_scores)
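As a side note, a minimal standalone sketch of the per-head scaling introduced in this hunk. The shapes are illustrative and not the exact Longformer windowed-attention shapes; only the broadcasting pattern is the same:

```python
# Broadcasting a (num_heads,) mask over attention probabilities whose head axis is dim 2.
import torch

batch_size, seq_len, num_heads, window = 2, 8, 4, 5
attn_probs = torch.rand(batch_size, seq_len, num_heads, window)

layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # zero out head 1

masked = layer_head_mask.view(1, 1, -1, 1) * attn_probs
assert torch.all(masked[:, :, 1, :] == 0)  # head 1 no longer contributes
```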
@@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module):
             global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
+                layer_head_mask=layer_head_mask,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module):
         self,
         hidden_states,
         max_num_global_attn_indices,
+        layer_head_mask,
         is_local_index_global_attn_nonzero,
         is_index_global_attn_nonzero,
         is_local_index_no_global_attn_nonzero,
@@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module):
             global_attn_scores, dim=-1, dtype=torch.float32
         )  # use fp32 for numerical stability

+        # apply layer head masking
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
+                batch_size, self.num_heads, max_num_global_attn_indices, seq_len
+            )
+            global_attn_probs_float = global_attn_probs_float.view(
+                batch_size * self.num_heads, max_num_global_attn_indices, seq_len
+            )
+
         global_attn_probs = F.dropout(
             global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
         )
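For reference, a small sketch of the reshape, mask, reshape pattern used for the global attention probabilities above. All sizes are made up for the example:

```python
# A (batch*heads, G, S) tensor is viewed as (batch, heads, G, S), scaled per head, then flattened back.
import torch

batch_size, num_heads, max_num_global_attn_indices, seq_len = 2, 4, 3, 10
global_attn_probs_float = torch.rand(batch_size * num_heads, max_num_global_attn_indices, seq_len)

layer_head_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # silence head 2

probs = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
    batch_size, num_heads, max_num_global_attn_indices, seq_len
)
probs = probs.view(batch_size * num_heads, max_num_global_attn_indices, seq_len)

# every row belonging to head 2 is now all zeros
assert torch.all(probs.view(batch_size, num_heads, -1)[:, 2] == 0)
```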
@@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module):
         self_outputs = self.self(
             hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        layer_head_mask=None,
         is_index_masked=None,
         is_index_global_attn=None,
         is_global_attn=None,
@@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module):
         self_attn_outputs = self.attention(
             hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
@@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        head_mask=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module):
         all_attentions = () if output_attentions else None  # All local attentions.
         all_global_attentions = () if (output_attentions and is_global_attn) else None

-        for i, layer_module in enumerate(self.layer):
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layer)
+            ), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
+        for idx, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
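A standalone sketch of the layer-count check and per-layer slicing performed in this loop, with illustrative layer and head counts:

```python
import torch

num_layers, num_heads = 12, 12  # made-up sizes, not tied to a specific config
head_mask = torch.ones(num_layers, num_heads)

assert head_mask.size()[0] == num_layers, (
    f"The head_mask should be specified for {num_layers} layers, but it is for {head_mask.size()[0]}."
)

for idx in range(num_layers):
    layer_head_mask = head_mask[idx]  # shape (num_heads,), handed to layer idx
    assert layer_head_mask.shape == (num_heads,)
```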
@@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module):
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
                     is_index_masked,
                     is_index_global_attn,
                 )
@@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module):
                 layer_outputs = layer_module(
                     hidden_states,
                     attention_mask=attention_mask,
+                    layer_head_mask=head_mask[idx] if head_mask is not None else None,
                     is_index_masked=is_index_masked,
                     is_index_global_attn=is_index_global_attn,
                     is_global_attn=is_global_attn,
@@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
            - 0 for local attention (a sliding window attention),
            - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).

+        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
             1]``:
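Following the docstring above, a head mask could be built like this. A sketch only; the layer and head counts are illustrative:

```python
import torch

num_layers, num_heads = 12, 12
head_mask = torch.ones(num_layers, num_heads)  # 1.0 = keep the head
head_mask[3, 5] = 0.0  # mask head 5 in layer 3
head_mask[7, :] = 0.0  # mask every head in layer 7
```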
@@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel):
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
         input_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         token_type_ids=None,
         position_ids=None,
         inputs_embeds=None,
@@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
@@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
         token_type_ids=None,
         attention_mask=None,
         global_attention_mask=None,
+        head_mask=None,
         labels=None,
         position_ids=None,
         inputs_embeds=None,
@@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
             token_type_ids=flat_token_type_ids,
             attention_mask=flat_attention_mask,
             global_attention_mask=flat_global_attention_mask,
+            head_mask=head_mask,
             inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
......
@@ -473,7 +473,6 @@ class ModelTesterMixin:
             arg_names = [*signature.parameters.keys()]
             if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
                 inputs["decoder_head_mask"] = head_mask
-
             outputs = model(**inputs, return_dict=True)

             # Test that we can get a gradient back for importance score computation
......
@@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
     decoder_input_ids,
     attention_mask=None,
     decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
         "decoder_attention_mask": decoder_attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
     }
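Note that the defaults added above are all-ones masks, i.e. they mask nothing, so pre-existing test expectations are unaffected. A quick sketch, with made-up shapes, of why an all-ones head mask is a no-op:

```python
import torch

num_heads = 4
attn_probs = torch.rand(2, 8, num_heads, 5)
identity_mask = torch.ones(num_heads)

# multiplying by 1.0 leaves the probabilities unchanged
assert torch.equal(identity_mask.view(1, 1, -1, 1) * attn_probs, attn_probs)
```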
@@ -160,9 +168,10 @@ class LEDModelTester:
         model = LEDModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]

         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

         output, past_key_values = outputs.to_tuple()
@@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
     test_pruning = False
-    test_head_masking = False
     test_missing_keys = False

     def setUp(self):
......
@@ -273,7 +273,6 @@ class LongformerModelTester:
 @require_torch
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False  # pruning is not supported
-    test_headmasking = False  # head masking is not supported
     test_torchscript = False

     all_model_classes = (
......