Unverified Commit 2ebbbf55 authored by Daniel Stancl, committed by GitHub

Add separated decoder_head_mask for T5 Models (#9634)

* Add decoder_head_mask for PyTorch T5 model

* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration

* Slightly change the order of input args to be in accordance
with the convention of BART-based models introduced in PR #9569.

* Make style for modeling_t5.py

* Add decoder_head_mask for TF T5 models

* Separate head_mask and decoder_head_mask args in TF T5 models

* Slightly change the order of input args to follow the convention
of BART-based models updated in PR #9569

* Update test_forward_signature in tests/test_modeling_tf_common.py
to reflect the changed order of input args

* Add FutureWarnings for T5 and TFT5 models

* Add FutureWarnings for T5 and TFT5 models warning the user that
the input argument `head_mask` was split into two arguments -
`head_mask` and `decoder_head_mask`

* Add default behaviour - `decoder_head_mask` is set to copy
`head_mask`

* Fix T5 modeling and FutureWarning

* Make proper use of head_mask and decoder_head_mask
in cross_attention

* Fix conditions for raising FutureWarning

* Reformat FutureWarning in T5 modeling

* Refactor the warning message
parent e4c06ed6
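
For reference, a minimal usage sketch of the split arguments on the PyTorch side, assuming the `t5-small` checkpoint and all-ones masks for illustration (shapes follow the updated docstrings: `(num_layers, num_heads)` per mask; the snippet itself is not part of the diff):

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

# Separate masks: one for encoder self-attention, one for decoder self-attention.
# 1.0 keeps a head, 0.0 masks it.
head_mask = torch.ones(model.config.num_layers, model.config.num_heads)
decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)
decoder_head_mask[0, 0] = 0.0  # e.g. mask the first head of the first decoder layer

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    labels=labels,
)
print(outputs.loss)
```

Because `decoder_head_mask` is passed explicitly here, no FutureWarning is raised.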
......@@ -18,6 +18,7 @@
import copy
import math
import os
import warnings
import torch
import torch.nn.functional as F
......@@ -409,7 +410,7 @@ class T5Attention(nn.Module):
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
......@@ -504,8 +505,8 @@ class T5Attention(nn.Module):
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
......@@ -530,7 +531,7 @@ class T5LayerSelfAttention(nn.Module):
hidden_states,
attention_mask=None,
position_bias=None,
head_mask=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
......@@ -540,7 +541,7 @@ class T5LayerSelfAttention(nn.Module):
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -563,7 +564,7 @@ class T5LayerCrossAttention(nn.Module):
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
......@@ -575,7 +576,7 @@ class T5LayerCrossAttention(nn.Module):
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
......@@ -605,7 +606,8 @@ class T5Block(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
layer_head_mask=None,
encoder_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
......@@ -632,7 +634,7 @@ class T5Block(nn.Module):
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -659,7 +661,7 @@ class T5Block(nn.Module):
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
......@@ -839,6 +841,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
......@@ -906,6 +909,7 @@ class T5Stack(T5PreTrainedModel):
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
......@@ -930,6 +934,10 @@ class T5Stack(T5PreTrainedModel):
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if head_mask is not None:
head_mask = head_mask.to(hidden_states.device)
if encoder_head_mask is not None:
encoder_head_mask = encoder_head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -940,7 +948,8 @@ class T5Stack(T5PreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
layer_head_mask=head_mask[i],
encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -1058,6 +1067,20 @@ T5_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
......@@ -1069,12 +1092,6 @@ T5_INPUTS_DOCSTRING = r"""
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
......@@ -1141,6 +1158,14 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
......@@ -1229,9 +1254,10 @@ class T5Model(T5PreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
......@@ -1258,6 +1284,12 @@ class T5Model(T5PreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
......@@ -1298,7 +1330,8 @@ class T5Model(T5PreTrainedModel):
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -1409,9 +1442,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
......@@ -1447,6 +1481,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
......@@ -1503,7 +1543,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......
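
The deprecation fallback added above (copy `head_mask` into `decoder_head_mask` with a FutureWarning when the encoder and decoder layer counts match) can be exercised directly. A small sketch, again assuming the `t5-small` checkpoint:

```python
import warnings

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
batch = tokenizer("summarize: a short example sentence", return_tensors="pt")

head_mask = torch.ones(model.config.num_layers, model.config.num_heads)

# Passing only `head_mask` triggers the warning and reuses it for the decoder.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model(**batch, decoder_input_ids=batch.input_ids, head_mask=head_mask)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Opting out, as the warning message suggests: pass an explicit decoder mask.
decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)
model(**batch, decoder_input_ids=batch.input_ids,
      head_mask=head_mask, decoder_head_mask=decoder_head_mask)
```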
......@@ -18,6 +18,7 @@
import copy
import itertools
import math
import warnings
from typing import Tuple
import tensorflow as tf
......@@ -245,7 +246,7 @@ class TFT5Attention(tf.keras.layers.Layer):
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
training=False,
......@@ -342,8 +343,8 @@ class TFT5Attention(tf.keras.layers.Layer):
weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
if layer_head_mask is not None:
weights = weights * layer_head_mask
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
......@@ -373,7 +374,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
hidden_states,
attention_mask=None,
position_bias=None,
head_mask=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
......@@ -384,7 +385,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -412,7 +413,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
layer_head_mask=None,
past_key_value=None,
query_length=None,
use_cache=False,
......@@ -425,7 +426,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
query_length=query_length,
use_cache=use_cache,
......@@ -467,7 +468,8 @@ class TFT5Block(tf.keras.layers.Layer):
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
layer_head_mask=None,
encoder_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
......@@ -494,7 +496,7 @@ class TFT5Block(tf.keras.layers.Layer):
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -516,7 +518,7 @@ class TFT5Block(tf.keras.layers.Layer):
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
......@@ -584,6 +586,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
......@@ -601,6 +604,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -709,6 +713,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
assert inputs["head_mask"] is None, "Head mask not supported"
inputs["head_mask"] = [None] * self.num_hidden_layers
assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
......@@ -727,7 +733,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=inputs["head_mask"][i],
layer_head_mask=inputs["head_mask"][i],
encoder_layer_head_mask=inputs["encoder_head_mask"][i],
past_key_value=past_key_value,
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
......@@ -950,6 +957,20 @@ T5_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor))`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
......@@ -973,12 +994,6 @@ T5_INPUTS_DOCSTRING = r"""
If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
takes the value of :obj:`inputs_embeds`.
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
......@@ -1037,6 +1052,13 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
behaviors between training and evaluation).
"""
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
num_heads))`.
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
......@@ -1075,9 +1097,10 @@ class TFT5Model(TFT5PreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
......@@ -1103,6 +1126,11 @@ class TFT5Model(TFT5PreTrainedModel):
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
func=self.call,
config=self.config,
......@@ -1110,9 +1138,10 @@ class TFT5Model(TFT5PreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
......@@ -1149,7 +1178,8 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
......@@ -1251,9 +1281,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
......@@ -1289,6 +1320,11 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
>>> result = model.generate(inputs)
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
func=self.call,
config=self.config,
......@@ -1296,9 +1332,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
......@@ -1340,7 +1377,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_hidden_states=hidden_states,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"],
head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
......
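
On the TensorFlow side the new arguments are threaded through the `call` signatures in the same order, but `TFT5MainLayer` still asserts that `head_mask` is `None` (see the "Head mask not supported" assertion in the diff above), so at this commit they are signature-level only. A quick way to check the updated ordering, assuming TensorFlow is installed, mirrors what the adjusted `test_forward_signature` below does:

```python
import inspect

from transformers import TFT5Model

arg_names = list(inspect.signature(TFT5Model.call).parameters)
# After this change, both masks come before `encoder_outputs`,
# matching the BART-style convention from PR #9569.
assert arg_names.index("head_mask") < arg_names.index("decoder_head_mask") < arg_names.index("encoder_outputs")
print(arg_names[:10])
```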
......@@ -155,9 +155,13 @@ class TFModelTesterMixin:
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"encoder_outputs",
]
self.assertListEqual(arg_names[:5], expected_arg_names)
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = ["input_ids"]
......