Unverified commit 2ebbbf55 authored by Daniel Stancl, committed by GitHub

Add separated decoder_head_mask for T5 Models (#9634)

* Add decoder_head_mask for PyTorch T5 model

* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration

* Slightly change the order of input args to be in accordance
with the convention of BART-based models introduced in PR #9569.

* Make style for modeling_t5.py

* Add decoder_head_mask for TF T5 models

* Separate head_mask and decoder_head_mask args in TF T5 models

* Slightly change the order of input args to follow the convention
of BART-based models updated in PR #9569

* Update test_forward_signature in tests/test_modeling_tf_common.py
to reflect the changed order of input args

* Add FutureWarnings for T5 and TFT5 models

* Add FutureWarnings for T5 and TFT5 models warning the user that
the input argument `head_mask` was split into two arguments -
`head_mask` and `decoder_head_mask`

* Add default behaviour - `decoder_head_mask` is set to copy
`head_mask`

* Fix T5 modeling and FutureWarning

* Make proper use of head_mask and decoder_head_mask
in cross-attention

* Fix conditions for raising FutureWarning

* Reformat FutureWarning in T5 modeling

* Refactor the warning message
parent e4c06ed6
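After this change, the encoder and decoder head masks are passed independently to the PyTorch T5 models. A minimal usage sketch of the new arguments (the checkpoint name, the example text, and the particular masked head are illustrative, not part of this PR):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")
labels = tokenizer("Hallo Welt", return_tensors="pt").input_ids

# One row per layer, one column per head; 1.0 keeps a head, 0.0 nullifies it.
head_mask = torch.ones(model.config.num_layers, model.config.num_heads)                  # encoder layers
decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)  # decoder layers
decoder_head_mask[0, 0] = 0.0  # e.g. drop head 0 in the first decoder layer

outputs = model(
    **inputs,
    labels=labels,
    head_mask=head_mask,                  # applied in the encoder (and, as `encoder_head_mask`, in cross-attention)
    decoder_head_mask=decoder_head_mask,  # applied in the decoder self-attention
)

Passing only `head_mask` is still accepted for backward compatibility: per this PR it is copied to `decoder_head_mask` and a FutureWarning is emitted, in the PyTorch models only when the encoder and decoder have the same number of layers.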
modeling_t5.py
@@ -18,6 +18,7 @@
 import copy
 import math
 import os
+import warnings

 import torch
 import torch.nn.functional as F
@@ -409,7 +410,7 @@ class T5Attention(nn.Module):
         key_value_states=None,
         position_bias=None,
         past_key_value=None,
-        head_mask=None,
+        layer_head_mask=None,
         query_length=None,
         use_cache=False,
         output_attentions=False,
@@ -504,8 +505,8 @@ class T5Attention(nn.Module):
         )  # (batch_size, n_heads, seq_length, key_length)

         # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+        if layer_head_mask is not None:
+            attn_weights = attn_weights * layer_head_mask

         attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
@@ -530,7 +531,7 @@ class T5LayerSelfAttention(nn.Module):
         hidden_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
@@ -540,7 +541,7 @@ class T5LayerSelfAttention(nn.Module):
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -563,7 +564,7 @@ class T5LayerCrossAttention(nn.Module):
         key_value_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         query_length=None,
@@ -575,7 +576,7 @@ class T5LayerCrossAttention(nn.Module):
             mask=attention_mask,
             key_value_states=key_value_states,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             query_length=query_length,
@@ -605,7 +606,8 @@ class T5Block(nn.Module):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
+        encoder_layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
@@ -632,7 +634,7 @@ class T5Block(nn.Module):
             hidden_states,
             attention_mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=self_attn_past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -659,7 +661,7 @@ class T5Block(nn.Module):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 query_length=query_length,
                 use_cache=use_cache,
@@ -839,6 +841,7 @@ class T5Stack(T5PreTrainedModel):
         encoder_attention_mask=None,
         inputs_embeds=None,
         head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         use_cache=None,
         output_attentions=None,
@@ -906,6 +909,7 @@ class T5Stack(T5PreTrainedModel):
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+        encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers)
         present_key_value_states = () if use_cache else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -930,6 +934,10 @@ class T5Stack(T5PreTrainedModel):
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if head_mask is not None:
+                    head_mask = head_mask.to(hidden_states.device)
+                if encoder_head_mask is not None:
+                    encoder_head_mask = encoder_head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -940,7 +948,8 @@ class T5Stack(T5PreTrainedModel):
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask[i],
+                layer_head_mask=head_mask[i],
+                encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
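Inside `T5Stack`, the stack-level masks are expanded by `get_head_mask` and then sliced per layer (`head_mask[i]`, `encoder_head_mask[i]`), so each `T5Block` only sees the mask for its own layer, which is multiplied into that layer's attention weights. A simplified sketch of the shapes (the dimensions are illustrative and the broadcasting is really handled by `get_head_mask`):

import torch

num_layers, num_heads, batch, seq_len = 6, 8, 2, 10
head_mask = torch.ones(num_layers, num_heads)
head_mask[2, 3] = 0.0  # nullify head 3 of layer 2

attn_weights = torch.rand(batch, num_heads, seq_len, seq_len)

i = 2  # layer index inside the stack's layer loop
layer_head_mask = head_mask[i].view(1, num_heads, 1, 1)  # broadcasts over (batch, n_heads, seq_length, key_length)
attn_weights = attn_weights * layer_head_mask            # what T5Attention does when `layer_head_mask` is not None
assert attn_weights[:, 3].abs().sum() == 0               # head 3 of this layer is zeroed out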
@@ -1058,6 +1067,20 @@ T5_INPUTS_DOCSTRING = r"""
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
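As the docstring above states, both masks accept either shape `(num_heads,)` (the same mask is broadcast to every layer by `get_head_mask`) or `(num_layers, num_heads)` (one mask per layer). A small sketch of the two forms, with illustrative dimensions:

import torch

num_layers, num_heads = 6, 8

# (num_heads,): head 5 is masked in every encoder layer
head_mask = torch.ones(num_heads)
head_mask[5] = 0.0

# (num_layers, num_heads): mask heads layer by layer in the decoder
decoder_head_mask = torch.ones(num_layers, num_heads)
decoder_head_mask[-1, :4] = 0.0  # drop the first four heads of the last decoder layer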
@@ -1069,12 +1092,6 @@ T5_INPUTS_DOCSTRING = r"""
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
@@ -1141,6 +1158,14 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """

+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
+

 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1229,9 +1254,10 @@ class T5Model(T5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         use_cache=None,
@@ -1258,6 +1284,12 @@ class T5Model(T5PreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -1298,7 +1330,8 @@ class T5Model(T5PreTrainedModel):
             past_key_values=past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
-            head_mask=head_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1409,9 +1442,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         labels=None,
@@ -1447,6 +1481,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
@@ -1503,7 +1543,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             past_key_values=past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
-            head_mask=head_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
modeling_tf_t5.py
@@ -18,6 +18,7 @@
 import copy
 import itertools
 import math
+import warnings
 from typing import Tuple

 import tensorflow as tf
@@ -245,7 +246,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         key_value_states=None,
         position_bias=None,
         past_key_value=None,
-        head_mask=None,
+        layer_head_mask=None,
         query_length=None,
         use_cache=False,
         training=False,
@@ -342,8 +343,8 @@ class TFT5Attention(tf.keras.layers.Layer):
         weights = self.dropout(weights, training=training)  # (batch_size, n_heads, query_length, key_length)

         # Mask heads if we want to
-        if head_mask is not None:
-            weights = weights * head_mask
+        if layer_head_mask is not None:
+            weights = weights * layer_head_mask

         attn_output = tf.matmul(weights, value_states)  # (batch_size, n_heads, query_length, dim_per_head)
@@ -373,7 +374,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         hidden_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
@@ -384,7 +385,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -412,7 +413,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         key_value_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         query_length=None,
         use_cache=False,
@@ -425,7 +426,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
             mask=attention_mask,
             key_value_states=key_value_states,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             query_length=query_length,
             use_cache=use_cache,
@@ -467,7 +468,8 @@ class TFT5Block(tf.keras.layers.Layer):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
+        encoder_layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
@@ -494,7 +496,7 @@ class TFT5Block(tf.keras.layers.Layer):
             hidden_states,
             attention_mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=self_attn_past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -516,7 +518,7 @@ class TFT5Block(tf.keras.layers.Layer):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 query_length=query_length,
                 use_cache=use_cache,
@@ -584,6 +586,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         encoder_attention_mask=None,
         inputs_embeds=None,
         head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         use_cache=None,
         output_attentions=None,
@@ -601,6 +604,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             encoder_attention_mask=encoder_attention_mask,
             inputs_embeds=inputs_embeds,
             head_mask=head_mask,
+            encoder_head_mask=encoder_head_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -709,6 +713,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         assert inputs["head_mask"] is None, "Head mask not supported"
         inputs["head_mask"] = [None] * self.num_hidden_layers
+        assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
+        inputs["encoder_head_mask"] = [None] * self.num_hidden_layers

         present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
         all_hidden_states = () if inputs["output_hidden_states"] else None
         all_attentions = () if inputs["output_attentions"] else None
@@ -727,7 +733,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 encoder_hidden_states=inputs["encoder_hidden_states"],
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                head_mask=inputs["head_mask"][i],
+                layer_head_mask=inputs["head_mask"][i],
+                encoder_layer_head_mask=inputs["encoder_head_mask"][i],
                 past_key_value=past_key_value,
                 use_cache=inputs["use_cache"],
                 output_attentions=inputs["output_attentions"],
@@ -950,6 +957,20 @@ T5_INPUTS_DOCSTRING = r"""
         decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
+        head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -973,12 +994,6 @@ T5_INPUTS_DOCSTRING = r"""
             If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
             takes the value of :obj:`inputs_embeds`.
-        head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
             decoding (see :obj:`past_key_values`).
@@ -1037,6 +1052,13 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
             behaviors between training and evaluation).
 """

+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
+num_heads))`.
+"""
+

 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1075,9 +1097,10 @@ class TFT5Model(TFT5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         use_cache=None,
@@ -1103,6 +1126,11 @@ class TFT5Model(TFT5PreTrainedModel):
         """

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+            decoder_head_mask = head_mask
+
         inputs = input_processing(
             func=self.call,
             config=self.config,
@@ -1110,9 +1138,10 @@ class TFT5Model(TFT5PreTrainedModel):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
@@ -1149,7 +1178,8 @@ class TFT5Model(TFT5PreTrainedModel):
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=inputs["attention_mask"],
             inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["head_mask"],
+            head_mask=inputs["decoder_head_mask"],
+            encoder_head_mask=inputs["head_mask"],
             past_key_values=inputs["past_key_values"],
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
@@ -1251,9 +1281,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         labels=None,
@@ -1289,6 +1320,11 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
         >>> result = model.generate(inputs)
         """

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+            decoder_head_mask = head_mask
+
         inputs = input_processing(
             func=self.call,
             config=self.config,
@@ -1296,9 +1332,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             labels=labels,
@@ -1340,7 +1377,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=inputs["attention_mask"],
             inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["head_mask"],
+            head_mask=inputs["decoder_head_mask"],
             past_key_values=inputs["past_key_values"],
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
tests/test_modeling_tf_common.py
@@ -155,9 +155,13 @@ class TFModelTesterMixin:
                     "input_ids",
                     "attention_mask",
                     "decoder_input_ids",
                     "decoder_attention_mask",
-                    "encoder_outputs",
                 ]
-                self.assertListEqual(arg_names[:5], expected_arg_names)
+                expected_arg_names.extend(
+                    ["head_mask", "decoder_head_mask", "encoder_outputs"]
+                    if "head_mask" in arg_names and "decoder_head_mask" in arg_names
+                    else ["encoder_outputs"]
+                )
+                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
             else:
                 expected_arg_names = ["input_ids"]
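For context, `arg_names` in the test above is taken from the model's call signature, and the assertion only checks the leading arguments, which is why the expected prefix must follow the argument order introduced by this PR. A rough sketch of how such a list can be derived (the helper name is made up for illustration; the real logic lives in the common test mixin):

import inspect

def leading_arg_names(model, n):
    # TF models expose the graph entry point as `call`; the PyTorch common tests inspect `forward` instead.
    signature = inspect.signature(model.call)
    arg_names = [name for name in signature.parameters.keys() if name != "self"]
    return arg_names[:n]

# For TFT5Model after this PR, the expected leading arguments are:
# ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask",
#  "head_mask", "decoder_head_mask", "encoder_outputs"]

Checking only the prefix keeps the test robust to arguments that follow `encoder_outputs`, while still pinning down the new `head_mask` / `decoder_head_mask` positions.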