Unverified Commit 0c6c0afc authored by Daniel Stancl, committed by GitHub

Add head_mask and decoder_head_mask to FSMT (#9819)

* Add {decoder_,}head_mask to fsmt_modeling.py

* Enable test_headmasking and some changes to docs

* Remove test_head_masking flag from fsmt test file

Remove test_head_masking flag from test_modeling_fsmt.py
since test_head_masking is set to True by default (thus it is redundant to store).

* Merge master and remove test_head_masking = True

* Rebase necessary due to an update of jaxlib

* Remove test_head_masking=True in tests/test_modeling_fsmt.py
as it is redundant.
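As a usage sketch (not part of this commit): the argument names `head_mask`/`decoder_head_mask`, the mask semantics (1 keeps a head, 0 masks it), and the config attribute names come from the diff below, while the checkpoint name is an illustrative assumption.

```python
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

# Illustrative checkpoint; any FSMT checkpoint should behave the same way.
model_name = "facebook/wmt19-en-de"
tokenizer = FSMTTokenizer.from_pretrained(model_name)
model = FSMTForConditionalGeneration.from_pretrained(model_name)

inputs = tokenizer("Machine learning is great", return_tensors="pt")

# One row per layer, one column per head; 1 keeps a head, 0 masks it.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
head_mask[0, 0] = 0  # mask the first head of the first encoder layer

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"],  # simple teacher-forcing input for illustration
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
```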
parent 74f16b82
@@ -240,6 +240,17 @@ FSMT_INPUTS_DOCSTRING = r"""
also be used by default. If you want to change padding behavior, you should read
:func:`modeling_fstm._prepare_fstm_decoder_inputs` and modify. See diagram 1 in the paper for more info on
the default strategy
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0):
def _prepare_fsmt_decoder_inputs(
config,
input_ids,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=torch.float32,
):
"""
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
@@ -377,21 +392,27 @@ class EncoderLayer(nn.Module):
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
"""
Args:
x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(
query=x,
key=x,
key_padding_mask=encoder_padding_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -432,21 +453,32 @@ class FSMTEncoder(nn.Module):
) # type: List[EncoderLayer]
def forward(
self,
input_ids,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
"""
Args:
input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape
`(batch, src_len)`
attention_mask (:obj:`torch.LongTensor`): indicating which indices are padding tokens
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
Returns:
BaseModelOutput or Tuple comprised of:
- **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (:obj:`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape
`(src_len, batch, embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (:obj:`Tuple(torch.FloatTensor)`): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
@@ -463,7 +495,12 @@ class FSMTEncoder(nn.Module):
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
x = x.transpose(0, 1) # T x B x C -> B x T x C
encoder_states += (x,)
@@ -473,7 +510,12 @@ class FSMTEncoder(nn.Module):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer(
x,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attn,)
@@ -522,6 +564,8 @@ class DecoderLayer(nn.Module):
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
layer_head_mask=None,
encoder_layer_head_mask=None,
decoder_padding_mask=None,
output_attentions=False,
):
@@ -537,6 +581,7 @@ class DecoderLayer(nn.Module):
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -551,6 +596,7 @@ class DecoderLayer(nn.Module):
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
layer_head_mask=encoder_layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -611,6 +657,8 @@ class FSMTDecoder(nn.Module):
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
@@ -622,12 +670,24 @@ class FSMTDecoder(nn.Module):
EMNLP 2019).
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`):
previous decoder outputs for teacher forcing
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
past_key_values (dict or None): dictionary used for storing state during generation
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
Returns:
BaseModelOutputWithPast or tuple:
@@ -662,6 +722,12 @@ class FSMTDecoder(nn.Module):
all_self_attns = () if output_attentions else None
all_cross_attns = () if output_attentions else None
next_decoder_cache = []
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -681,6 +747,8 @@ class FSMTDecoder(nn.Module):
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
output_attentions=output_attentions,
)
@@ -761,6 +829,7 @@ class Attention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
@@ -830,6 +899,13 @@ class Attention(nn.Module):
attn_weights = F.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# make sure that attn_weights are included in graph
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
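For reference, a minimal standalone sketch (plain PyTorch, synthetic shapes, no FSMT code) of the broadcasting used in the new `layer_head_mask` block above: a `(num_heads,)` mask is reshaped to `(1, num_heads, 1, 1)` and multiplied into the attention weights, zeroing the masked heads.

```python
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[1] = 0.0  # mask the second head

# Same broadcasting pattern as in Attention.forward above:
# (1, num_heads, 1, 1) * (bsz, num_heads, tgt_len, src_len)
masked = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, num_heads, tgt_len, src_len)
masked = masked.view(bsz * num_heads, tgt_len, src_len)

# The attention rows belonging to the masked head are now all zeros.
assert torch.all(masked.view(bsz, num_heads, tgt_len, src_len)[:, 1] == 0)
```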
@@ -923,6 +999,8 @@ class FSMTModel(PretrainedFSMTModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
use_cache=None,
@@ -958,6 +1036,7 @@ class FSMTModel(PretrainedFSMTModel):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -977,6 +1056,8 @@ class FSMTModel(PretrainedFSMTModel):
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -1052,6 +1133,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
labels=None,
@@ -1080,6 +1163,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
...
@@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):
...
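For orientation, a small sketch of the default masks the updated prepare_fsmt_inputs_dict now builds, together with the kind of partial masking a head-masking test typically applies. The layer and head counts below are illustrative assumptions, not taken from the diff.

```python
import torch

# Mirrors the defaults added to prepare_fsmt_inputs_dict: all-ones masks
# (nothing masked), one row per layer and one column per head.
encoder_layers, encoder_attention_heads = 2, 4  # illustrative sizes
decoder_layers, decoder_attention_heads = 2, 4

head_mask = torch.ones(encoder_layers, encoder_attention_heads)
decoder_head_mask = torch.ones(decoder_layers, decoder_attention_heads)

# A head-masking test can then zero out specific heads and check that the
# corresponding attention weights vanish, e.g.:
head_mask[0, 0] = 0.0    # mask the first head of the first layer
head_mask[-1, :-1] = 0.0  # keep only the last head in the last layer
```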