Unverified Commit 38a716cd authored by Daniel Stancl, committed by GitHub


TF BART models - Add `cross_attentions` to model output and fix cross-attention head masking (#10699)

* Add cross_attn_head_mask to BART

* Fix cross_attentions in TFBart-like models

* This commit enables returning `cross_attentions`
for TFBart-like models

* It also fixes attention head masking in the cross-attention module

* Update TF model templates

* Fix missing , in TF model templates

* Fix typo: congig -> config
parent 4bd6b54f
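Taken together, the changes below mean a TFBart-like model accepts a `cross_attn_head_mask` argument and, when called with `output_attentions=True`, returns per-layer `cross_attentions`. A minimal sketch of how that might be exercised (the checkpoint name and exact call pattern are illustrative, not part of this commit):

```python
# Illustrative usage, assuming the "facebook/bart-base" checkpoint; any TFBart-like model applies.
import tensorflow as tf
from transformers import AutoTokenizer, TFBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

batch = tokenizer("Hello, my dog is cute", return_tensors="tf")

# One row per decoder layer, one column per decoder attention head; 1.0 keeps a head, 0.0 masks it.
cross_attn_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))

outputs = model(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
    return_dict=True,
)

# Tuple with one tensor per decoder layer, each of shape (batch, num_heads, target_len, source_len).
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```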
@@ -116,6 +116,82 @@ class TFBaseModelOutputWithPast(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithCrossAttentions(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFSeq2SeqModelOutput(ModelOutput):
"""
@@ -145,6 +221,12 @@ class TFSeq2SeqModelOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -164,6 +246,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
@@ -290,6 +373,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -310,6 +399,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
...
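For reference, the two output classes added above behave like other `ModelOutput` subclasses (attribute and key access). A hedged, self-contained sketch with dummy tensors (shapes chosen only for illustration, not taken from this commit):

```python
# Illustration only: building the new output class by hand with dummy tensors.
# In practice the decoders below return it; the shapes here are arbitrary.
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions

batch, tgt_len, src_len, heads, hidden = 2, 5, 7, 4, 16

out = TFBaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=tf.zeros((batch, tgt_len, hidden)),
    past_key_values=None,
    hidden_states=None,
    attentions=(tf.zeros((batch, heads, tgt_len, tgt_len)),),        # self-attention, one entry per layer
    cross_attentions=(tf.zeros((batch, heads, tgt_len, src_len)),),  # decoder-over-encoder attention
)

# ModelOutput subclasses allow both attribute and key access.
assert out["cross_attentions"] is out.cross_attentions
```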
@@ -30,7 +30,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
-TFBaseModelOutputWithPast,
+TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
@@ -365,7 +365,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
-encoder_layer_head_mask: Optional[tf.Tensor] = None,
+cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -379,8 +379,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
-encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
-`(encoder_attention_heads,)`
+cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
+`(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -401,16 +401,17 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block
cross_attn_present_key_value = None
+cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
-layer_head_mask=encoder_layer_head_mask,
+layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -432,6 +433,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -572,7 +574,7 @@ BART_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -580,6 +582,12 @@ BART_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -677,7 +685,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -814,7 +822,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -856,14 +864,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -894,7 +901,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
-encoder_head_mask=encoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -949,17 +956,19 @@ class TFBartDecoder(tf.keras.layers.Layer):
# decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None
+all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
-# check if head_mask has a correct number of layers specified if desired
+# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
-if inputs["head_mask"] is not None and tf.executing_eagerly():
-tf.debugging.assert_equal(
-shape_list(inputs["head_mask"])[0],
-len(self.layers),
-message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
-)
+for attn_mask in ["head_mask", "cross_attn_head_mask"]:
+if inputs[attn_mask] is not None and tf.executing_eagerly():
+tf.debugging.assert_equal(
+shape_list(inputs[attn_mask])[0],
+len(self.layers),
+message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -973,14 +982,14 @@ class TFBartDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
-hidden_states, layer_self_attn, present_key_value = decoder_layer(
+hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
-if inputs["encoder_head_mask"] is not None
+cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
+if inputs["cross_attn_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -991,23 +1000,30 @@ class TFBartDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,)
+if inputs["encoder_hidden_states"] is not None:
+all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
+if inputs["encoder_hidden_states"] is not None:
+all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
-return hidden_states, present_key_values, all_hidden_states, all_self_attns
+return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
+cross_attentions=all_cross_attns,
)
@@ -1057,6 +1073,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1077,6 +1094,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1131,7 +1149,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
-encoder_head_mask=inputs["head_mask"],
+cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1149,6 +1167,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1189,6 +1208,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1209,6 +1229,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1228,6 +1249,7 @@ class TFBartModel(TFBartPretrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1245,6 +1267,7 @@ class TFBartModel(TFBartPretrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1253,6 +1276,7 @@ class TFBartModel(TFBartPretrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
@@ -1309,6 +1333,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1339,6 +1364,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1372,6 +1398,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1394,6 +1421,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
@@ -1403,6 +1431,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1411,6 +1440,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...
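The BART decoder above (and the Blenderbot decoder below) now validates both `head_mask` and `cross_attn_head_mask` against the number of decoder layers. A standalone sketch of that pattern, with an assumed helper name and signature (the real code reads from its `inputs` dict and `self.layers`):

```python
# Standalone sketch of the per-layer head-mask check used by the decoders; not library API.
import tensorflow as tf

def check_head_masks(num_layers, head_mask=None, cross_attn_head_mask=None):
    # The tf.debugging asserts are skipped outside eager mode because they are not XLA-compliant.
    for name, mask in {"head_mask": head_mask, "cross_attn_head_mask": cross_attn_head_mask}.items():
        if mask is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                tf.shape(mask)[0],
                num_layers,
                message=f"The {name} should be specified for {num_layers} layers, but it is for {mask.shape[0]}.",
            )

# Example: 6 decoder layers, 8 attention heads per layer.
check_head_masks(6, head_mask=tf.ones((6, 8)), cross_attn_head_mask=tf.ones((6, 8)))
```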
@@ -32,7 +32,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
-TFBaseModelOutputWithPast,
+TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
@@ -370,7 +370,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
-encoder_layer_head_mask: Optional[tf.Tensor] = None,
+cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -384,8 +384,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
-encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
-`(encoder_attention_heads,)`
+cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
+`(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -406,17 +406,18 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block
cross_attn_present_key_value = None
+cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
-layer_head_mask=encoder_layer_head_mask,
+layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -437,6 +438,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
@@ -569,7 +571,7 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
@@ -577,6 +579,12 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -674,7 +682,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -818,7 +826,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
-encoder_head_mask=None,
+cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -860,14 +868,13 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
-encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-on hidden heads. Mask values selected in ``[0, 1]``:
+cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
+- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -904,7 +911,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
-encoder_head_mask=encoder_head_mask,
+cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -957,19 +964,21 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers
-all_hidden_states = ()
-all_self_attns = ()
-present_key_values = ()
+all_hidden_states = () if inputs["output_hidden_states"] else None
+all_self_attns = () if inputs["output_attentions"] else None
+all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
+present_key_values = () if inputs["use_cache"] else None
-# check if head_mask has a correct number of layers specified if desired
+# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
-if inputs["head_mask"] is not None and tf.executing_eagerly():
-tf.debugging.assert_equal(
-shape_list(inputs["head_mask"])[0],
-len(self.layers),
-message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
-)
+for attn_mask in ["head_mask", "cross_attn_head_mask"]:
+if inputs[attn_mask] is not None and tf.executing_eagerly():
+tf.debugging.assert_equal(
+shape_list(inputs[attn_mask])[0],
+len(self.layers),
+message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -982,14 +991,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
-hidden_states, layer_self_attn, present_key_value = decoder_layer(
+hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
-if inputs["encoder_head_mask"] is not None
+cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
+if inputs["cross_attn_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -1000,25 +1009,32 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,)
+if inputs["encoder_hidden_states"] is not None:
+all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
-else:
-all_hidden_states = None
-all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
-present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
+if inputs["output_attentions"]:
+all_self_attns = list(all_self_attns)
+if inputs["encoder_hidden_states"] is not None:
+all_cross_attns = list(all_cross_attns)
+if inputs["use_cache"]:
+present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
-return hidden_states, present_key_values, all_hidden_states, all_self_attns
+return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
+cross_attentions=all_cross_attns,
)
@@ -1065,6 +1081,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1085,6 +1102,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1131,7 +1149,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
-encoder_head_mask=inputs["head_mask"],
+cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1149,6 +1167,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1199,6 +1218,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1219,6 +1239,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -1238,6 +1259,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
@@ -1256,6 +1278,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
@@ -1264,6 +1287,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -1331,6 +1355,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1331,6 +1355,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1361,6 +1386,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1361,6 +1386,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1394,6 +1420,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1394,6 +1420,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1416,6 +1443,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1416,6 +1443,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -1426,6 +1454,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1426,6 +1454,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1434,6 +1463,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1434,6 +1463,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
......
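For orientation, here is a minimal sketch of how the `cross_attentions` output added to the Blenderbot classes above can be inspected once this change is in place. The checkpoint name, the example sentence, and the printed shapes are illustrative assumptions, not something this diff prescribes.

import tensorflow as tf
from transformers import BlenderbotTokenizer, TFBlenderbotForConditionalGeneration

# Checkpoint picked purely for illustration (assumption).
name = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(name)
model = TFBlenderbotForConditionalGeneration.from_pretrained(name)

batch = tokenizer("Cross-attention weights are now returned.", return_tensors="tf")
outputs = model(
    **batch,
    # Assumes the config defines decoder_start_token_id, as the BART family does.
    decoder_input_ids=tf.fill((1, 1), model.config.decoder_start_token_id),
    output_attentions=True,
)

# One tensor per decoder layer, each of shape
# (batch_size, decoder_attention_heads, target_length, source_length).
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)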
...@@ -30,7 +30,7 @@ from ...file_utils import ( ...@@ -30,7 +30,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
...@@ -369,7 +369,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer): ...@@ -369,7 +369,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -383,8 +383,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer): ...@@ -383,8 +383,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -405,16 +405,17 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer): ...@@ -405,16 +405,17 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -436,6 +437,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer): ...@@ -436,6 +437,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
...@@ -574,7 +576,7 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" ...@@ -574,7 +576,7 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
...@@ -582,6 +584,12 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" ...@@ -582,6 +584,12 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
...@@ -679,7 +687,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -679,7 +687,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -823,7 +831,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -823,7 +831,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -865,14 +873,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -865,14 +873,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -909,7 +916,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -909,7 +916,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -960,19 +967,21 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -960,19 +967,21 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager. # have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
tf.debugging.assert_equal( if inputs[attn_mask] is not None and tf.executing_eagerly():
shape_list(inputs["head_mask"])[0], tf.debugging.assert_equal(
len(self.layers), shape_list(inputs[attn_mask])[0],
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", len(self.layers),
) message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
...@@ -985,14 +994,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -985,14 +994,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -1003,23 +1012,30 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -1003,23 +1012,30 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
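The loop over ``head_mask`` and ``cross_attn_head_mask`` earlier in this hunk enforces one mask row per decoder layer, but only in eager mode. A self-contained illustration of that contract follows; the helper name and the toy sizes are mine, not part of the library.

import tensorflow as tf

def check_mask_rows(masks: dict, num_layers: int) -> None:
    # Mirrors the eager-only check above: each provided mask must have
    # shape (num_layers, num_heads), i.e. one row per decoder layer.
    for name, mask in masks.items():
        if mask is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                mask.shape[0],
                num_layers,
                message=f"The {name} should be specified for {num_layers} layers, "
                f"but it is for {mask.shape[0]}.",
            )

# Passes: both masks cover all 2 layers (8 heads each, toy sizes).
check_mask_rows(
    {"head_mask": tf.ones((2, 8)), "cross_attn_head_mask": tf.ones((2, 8))},
    num_layers=2,
)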
...@@ -1066,6 +1082,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -1066,6 +1082,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1086,6 +1103,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -1086,6 +1103,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1132,7 +1150,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -1132,7 +1150,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1150,6 +1168,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -1150,6 +1168,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -1187,6 +1206,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1187,6 +1206,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1207,6 +1227,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1207,6 +1227,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1226,6 +1247,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1226,6 +1247,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1244,6 +1266,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1244,6 +1266,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1252,6 +1275,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1252,6 +1275,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -1306,6 +1330,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1306,6 +1330,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1336,6 +1361,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1336,6 +1361,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1369,6 +1395,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1369,6 +1395,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1391,6 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1391,6 +1418,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -1401,6 +1429,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1401,6 +1429,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1409,6 +1438,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1409,6 +1438,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
......
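The ``serving_output`` changes above run the per-layer ``cross_attentions`` tuple through ``tf.convert_to_tensor``, exactly as was already done for the self-attention maps. A small sketch of what that conversion does; all sizes below are made up for illustration.

import tensorflow as tf

# Pretend output of a 6-layer decoder: one attention map per layer, each of
# shape (batch_size=1, num_heads=8, target_length=5, source_length=7).
per_layer_cross_attentions = tuple(tf.random.uniform((1, 8, 5, 7)) for _ in range(6))

# tf.convert_to_tensor stacks the equally shaped tensors along a new leading
# axis, so the layer index becomes the first dimension of the served tensor.
stacked = tf.convert_to_tensor(per_layer_cross_attentions)
print(stacked.shape)  # (6, 1, 8, 5, 7)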
...@@ -31,7 +31,7 @@ from ...file_utils import ( ...@@ -31,7 +31,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
...@@ -408,7 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -408,7 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -422,8 +422,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -422,8 +422,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -444,16 +444,17 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -444,16 +444,17 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -475,6 +476,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -475,6 +476,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
...@@ -603,7 +605,7 @@ MARIAN_INPUTS_DOCSTRING = r""" ...@@ -603,7 +605,7 @@ MARIAN_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
...@@ -611,6 +613,12 @@ MARIAN_INPUTS_DOCSTRING = r""" ...@@ -611,6 +613,12 @@ MARIAN_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
...@@ -707,7 +715,7 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -707,7 +715,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
...@@ -848,7 +856,7 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -848,7 +856,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -890,14 +898,13 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -890,14 +898,13 @@ class TFMarianDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -934,7 +941,7 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -934,7 +941,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -986,19 +993,21 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -986,19 +993,21 @@ class TFMarianDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager. # have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
tf.debugging.assert_equal( if inputs[attn_mask] is not None and tf.executing_eagerly():
shape_list(inputs["head_mask"])[0], tf.debugging.assert_equal(
len(self.layers), shape_list(inputs[attn_mask])[0],
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", len(self.layers),
) message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
...@@ -1011,14 +1020,14 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -1011,14 +1020,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -1029,23 +1038,30 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -1029,23 +1038,30 @@ class TFMarianDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
...@@ -1092,6 +1108,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1092,6 +1108,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1112,6 +1129,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1112,6 +1129,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1161,7 +1179,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1161,7 +1179,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1179,6 +1197,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1179,6 +1197,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -1216,6 +1235,7 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1216,6 +1235,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1235,6 +1255,7 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1235,6 +1255,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
...@@ -1255,6 +1276,7 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1255,6 +1276,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1273,6 +1295,7 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1273,6 +1295,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1281,6 +1304,7 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1281,6 +1304,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -1335,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1335,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1365,6 +1390,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1365,6 +1390,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1398,6 +1424,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1398,6 +1424,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1420,6 +1447,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1420,6 +1447,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -1430,6 +1458,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1430,6 +1458,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1438,6 +1467,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1438,6 +1467,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
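With `cross_attentions` now part of `TFSeq2SeqLMOutput`, the decoder's cross-attention weights can be read straight off the Marian model output when `output_attentions=True`. A minimal sketch, assuming the public `Helsinki-NLP/opus-mt-en-de` checkpoint (any Marian checkpoint behaves the same way):

    from transformers import MarianTokenizer, TFMarianMTModel

    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    model = TFMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    batch = tokenizer("Hello world", return_tensors="tf")
    outputs = model(
        batch.input_ids,
        decoder_input_ids=batch.input_ids,  # illustrative only; normally derived from the labels
        output_attentions=True,
    )

    # one tensor per decoder layer, each of shape (batch, num_heads, tgt_len, src_len)
    print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)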
......
...@@ -30,7 +30,7 @@ from ...file_utils import ( ...@@ -30,7 +30,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
...@@ -368,7 +368,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -368,7 +368,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -382,8 +382,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -382,8 +382,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -404,17 +404,18 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -404,17 +404,18 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -435,6 +436,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -435,6 +436,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
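The decoder layer now returns the cross-attention weights as an extra element between the self-attention weights and the cached key/values, so every caller has to unpack four values instead of three. A quick way to see the new contract on a randomly initialized layer with a toy config (this uses an internal class and arbitrary sizes, purely for illustration):

    import tensorflow as tf
    from transformers import MBartConfig
    from transformers.models.mbart.modeling_tf_mbart import TFMBartDecoderLayer

    config = MBartConfig(d_model=16, decoder_attention_heads=4, decoder_ffn_dim=32)
    layer = TFMBartDecoderLayer(config)

    hidden_states = tf.random.normal((1, 5, 16))           # (batch, tgt_len, d_model)
    encoder_hidden_states = tf.random.normal((1, 7, 16))   # (batch, src_len, d_model)

    hidden_states, self_attn_weights, cross_attn_weights, present_key_value = layer(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        cross_attn_layer_head_mask=None,  # or a tensor of shape (decoder_attention_heads,)
    )
    print(cross_attn_weights.shape)  # (batch, num_heads, tgt_len, src_len) -> (1, 4, 5, 7)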
...@@ -547,7 +549,7 @@ MBART_INPUTS_DOCSTRING = r""" ...@@ -547,7 +549,7 @@ MBART_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
...@@ -555,6 +557,12 @@ MBART_INPUTS_DOCSTRING = r""" ...@@ -555,6 +557,12 @@ MBART_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
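A short sketch of how the new `cross_attn_head_mask` argument is meant to be used, with a small randomly initialized model so it runs without downloading a checkpoint (the config values are arbitrary):

    import tensorflow as tf
    from transformers import MBartConfig, TFMBartModel

    config = MBartConfig(
        vocab_size=100, d_model=16,
        encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = TFMBartModel(config)

    input_ids = tf.constant([[5, 6, 7, 8]])
    # zero out the first cross-attention head in every decoder layer, keep the rest
    cross_attn_head_mask = tf.concat(
        [tf.zeros((config.decoder_layers, 1)),
         tf.ones((config.decoder_layers, config.decoder_attention_heads - 1))],
        axis=-1,
    )

    outputs = model(
        input_ids,
        decoder_input_ids=input_ids,
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )
    # the masked head contributes all-zero attention weights
    print(outputs.cross_attentions[0][:, 0])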
...@@ -828,7 +836,7 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -828,7 +836,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -870,14 +878,13 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -870,14 +878,13 @@ class TFMBartDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -914,7 +921,7 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -914,7 +921,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -967,19 +974,21 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -967,19 +974,21 @@ class TFMBartDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
tf.debugging.assert_equal( if inputs[attn_mask] is not None and tf.executing_eagerly():
shape_list(inputs["head_mask"])[0], tf.debugging.assert_equal(
len(self.layers), shape_list(inputs[attn_mask])[0],
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", len(self.layers),
) message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
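Because the `tf.debugging` asserts cannot be traced by XLA, this layer-count check only fires in eager mode. A tiny standalone sketch of what the loop now enforces for both `head_mask` and `cross_attn_head_mask` (the numbers are made up):

    import tensorflow as tf

    num_layers = 2
    bad_mask = tf.ones((3, 4))  # mask rows for 3 layers, but the decoder only has 2

    try:
        tf.debugging.assert_equal(
            tf.shape(bad_mask)[0],
            num_layers,
            message=f"The cross_attn_head_mask should be specified for {num_layers} layers, "
            f"but it is for {bad_mask.shape[0]}.",
        )
    except tf.errors.InvalidArgumentError as err:
        print(err.message)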
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
...@@ -992,14 +1001,14 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -992,14 +1001,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -1010,25 +1019,32 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -1010,25 +1019,32 @@ class TFMBartDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
...@@ -1075,6 +1091,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1075,6 +1091,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1095,6 +1112,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1095,6 +1112,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1147,7 +1165,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1147,7 +1165,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1165,6 +1183,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1165,6 +1183,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -1202,6 +1221,7 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1202,6 +1221,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1222,6 +1242,7 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1222,6 +1242,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1241,6 +1262,7 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1241,6 +1262,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1259,6 +1281,7 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1259,6 +1281,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1267,6 +1290,7 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1267,6 +1290,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -1321,6 +1345,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1321,6 +1345,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1351,6 +1376,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1351,6 +1376,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1382,6 +1408,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1382,6 +1408,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1404,6 +1431,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1404,6 +1431,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -1414,6 +1442,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1414,6 +1442,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1422,6 +1451,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1422,6 +1451,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
......
...@@ -31,7 +31,7 @@ from ...file_utils import ( ...@@ -31,7 +31,7 @@ from ...file_utils import (
) )
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput, TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput, TFSeq2SeqModelOutput,
) )
...@@ -409,7 +409,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -409,7 +409,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -423,8 +423,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -423,8 +423,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)` `(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
`(encoder_attention_heads,)` `(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -445,17 +445,18 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -445,17 +445,18 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -476,6 +477,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -476,6 +477,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
return ( return (
hidden_states, hidden_states,
self_attn_weights, self_attn_weights,
cross_attn_weights,
present_key_value, present_key_value,
) )
...@@ -603,7 +605,7 @@ PEGASUS_INPUTS_DOCSTRING = r""" ...@@ -603,7 +605,7 @@ PEGASUS_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
...@@ -611,6 +613,12 @@ PEGASUS_INPUTS_DOCSTRING = r""" ...@@ -611,6 +613,12 @@ PEGASUS_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**. - 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
...@@ -855,7 +863,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -855,7 +863,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None, head_mask=None,
encoder_head_mask=None, cross_attn_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -897,14 +905,13 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -897,14 +905,13 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**, - 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**. - 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
...@@ -941,7 +948,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -941,7 +948,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_head_mask=encoder_head_mask, cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -993,19 +1000,21 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -993,19 +1000,21 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states + positions, training=inputs["training"]) hidden_states = self.dropout(hidden_states + positions, training=inputs["training"])
# decoder layers # decoder layers
all_hidden_states = () all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager. # have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): for attn_mask in ["head_mask", "cross_attn_head_mask"]:
tf.debugging.assert_equal( if inputs[attn_mask] is not None and tf.executing_eagerly():
shape_list(inputs["head_mask"])[0], tf.debugging.assert_equal(
len(self.layers), shape_list(inputs[attn_mask])[0],
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", len(self.layers),
) message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
...@@ -1018,14 +1027,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -1018,14 +1027,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx] cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["encoder_head_mask"] is not None if inputs["cross_attn_head_mask"] is not None
else None, else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -1036,25 +1045,32 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -1036,25 +1045,32 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else: else:
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_values, past_key_values=present_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
cross_attentions=all_cross_attns,
) )
...@@ -1101,6 +1117,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1101,6 +1117,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1121,6 +1138,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1121,6 +1138,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1170,7 +1188,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1170,7 +1188,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"], head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"], cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1188,6 +1206,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1188,6 +1206,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -1225,6 +1244,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1225,6 +1244,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1245,6 +1265,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1245,6 +1265,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1264,6 +1285,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1264,6 +1285,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1282,6 +1304,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1282,6 +1304,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1290,6 +1313,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1290,6 +1313,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -1344,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1344,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1374,6 +1399,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1374,6 +1399,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask, head_mask=head_mask,
decoder_head_mask=decoder_head_mask, decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1407,6 +1433,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1407,6 +1433,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"], decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1429,6 +1456,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1429,6 +1456,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -1439,6 +1467,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1439,6 +1467,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1447,6 +1476,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1447,6 +1476,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
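The same field is exposed for Pegasus. One way to see how `cross_attentions` differs from `decoder_attentions` (attention over the source sequence vs. over the target sequence), again with a small randomly initialized model so nothing has to be downloaded and the config values are arbitrary:

    import tensorflow as tf
    from transformers import PegasusConfig, TFPegasusModel

    config = PegasusConfig(
        vocab_size=100, d_model=16,
        encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = TFPegasusModel(config)

    src = tf.constant([[5, 6, 7, 8, 9, 10]])  # src_len = 6
    tgt = tf.constant([[5, 6, 7]])            # tgt_len = 3
    outputs = model(src, decoder_input_ids=tgt, output_attentions=True)

    print(outputs.decoder_attentions[0].shape)  # (1, 4, 3, 3): target attends to target
    print(outputs.cross_attentions[0].shape)    # (1, 4, 3, 6): target attends to source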
......
...@@ -147,7 +147,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer): ...@@ -147,7 +147,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
return final_embeddings return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer): class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
...@@ -352,6 +351,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer): ...@@ -352,6 +351,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
return outputs return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
...@@ -625,7 +625,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel): ...@@ -625,7 +625,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
base_model_prefix = "{{cookiecutter.lowercase_modelname}}" base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
...@@ -885,6 +884,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -885,6 +884,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings( @add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING """{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
) )
...@@ -1728,16 +1728,18 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer): ...@@ -1728,16 +1728,18 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
""" """
Args: Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
# The tf.debugging asserts are not compliant with XLA then they # The tf.debugging asserts are not compliant with XLA then they
...@@ -1798,6 +1800,8 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -1809,6 +1813,10 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for attention heads of the cross-attention module of size
`(decoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
...@@ -1821,6 +1829,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
...@@ -1828,15 +1837,17 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
...@@ -1858,6 +1869,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
return (
hidden_states,
self_attn_weights,
cross_attn_weights,
present_key_value,
)
...@@ -1965,6 +1977,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
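The three head-mask arguments documented above can be exercised directly against a TF BART checkpoint. A minimal, hedged sketch (the checkpoint choice and the toy decoder inputs are illustrative, not part of this diff):

import numpy as np
import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
batch = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="tf")

# 1 keeps a head, 0 masks it; here we drop head 0 of every decoder cross-attention layer
mask = np.ones((model.config.decoder_layers, model.config.decoder_attention_heads), dtype="float32")
mask[:, 0] = 0.0

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=batch["input_ids"],  # toy target just to run the decoder
    cross_attn_head_mask=tf.constant(mask),
    output_attentions=True,
)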
...@@ -2013,7 +2043,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
self.max_source_positions = config.max_position_embeddings
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
config.max_position_embeddings,
...@@ -2034,6 +2063,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
...@@ -2058,6 +2088,12 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -2082,6 +2118,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
...@@ -2115,6 +2152,16 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
...@@ -2125,7 +2172,11 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
hidden_states, attn = encoder_layer(
hidden_states,
inputs["attention_mask"],
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
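For context, the per-layer row passed here is consumed inside the attention layer roughly as follows; this is an illustrative sketch of the usual TFBart-style head masking (the attention module itself is not part of this hunk, so names and shapes are assumptions):

import tensorflow as tf

def apply_layer_head_mask(attn_weights, layer_head_mask, bsz, num_heads, tgt_len, src_len):
    # attn_weights: (bsz * num_heads, tgt_len, src_len); layer_head_mask: (num_heads,)
    if layer_head_mask is not None:
        masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
            attn_weights, (bsz, num_heads, tgt_len, src_len)
        )
        attn_weights = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))
    return attn_weights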
...@@ -2181,6 +2232,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
...@@ -2218,6 +2271,18 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
...@@ -2252,6 +2317,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
...@@ -2297,8 +2364,20 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
# decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs[attn_mask])[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
...@@ -2311,11 +2390,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
hidden_states, layer_self_attn, present_key_value = decoder_layer(
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["cross_attn_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
...@@ -2325,23 +2408,30 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)
if inputs["encoder_hidden_states"] is not None:
all_cross_attns = list(all_cross_attns)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
return TFBaseModelOutputWithPast(
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attns,
)
@tf.function
...@@ -2413,6 +2503,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
...@@ -2431,6 +2524,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
...@@ -2450,6 +2546,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
...@@ -2472,6 +2569,8 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
...@@ -2489,6 +2588,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
...@@ -2524,6 +2624,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
...@@ -2542,6 +2645,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
...@@ -2559,6 +2665,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
...@@ -2577,6 +2686,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -2585,6 +2695,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
...@@ -2637,6 +2748,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
...@@ -2672,6 +2786,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
...@@ -2698,6 +2815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -2720,6 +2840,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -2730,6 +2851,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -2738,6 +2860,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
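With these fields wired through, callers can read cross-attention weights off the standard seq2seq output. A hedged usage sketch (checkpoint and inputs are illustrative):

from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
batch = tokenizer("Studies show that good examples help.", return_tensors="tf")

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=batch["input_ids"],  # toy target just to exercise the decoder
    output_attentions=True,
)
# one tensor per decoder layer, each (batch, decoder_attention_heads, tgt_len, src_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)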
......
...@@ -147,6 +147,7 @@ def prepare_bart_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -162,13 +163,16 @@ def prepare_bart_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
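The helper's defaults boil down to all-ones masks, which leave every head active and therefore should not change model outputs. A minimal standalone sketch of those defaults (`config` is assumed to expose the usual BART-style layer and head counts):

import tensorflow as tf

def default_head_masks(config):
    # 1.0 everywhere means "keep every head" for encoder, decoder and cross-attention
    head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
    decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    return head_mask, decoder_head_mask, cross_attn_head_mask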
......
...@@ -146,6 +146,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -161,6 +162,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
...@@ -168,6 +171,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
......
...@@ -146,6 +146,7 @@ def prepare_blenderbot_small_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -161,6 +162,8 @@ def prepare_blenderbot_small_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
...@@ -168,6 +171,7 @@ def prepare_blenderbot_small_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
......
...@@ -190,8 +190,12 @@ class TFModelTesterMixin:
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
else ["encoder_outputs"]
)
expected_arg_names.extend(
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
)
# Necessary to handle BART with newly added cross_attn_head_mask
expected_arg_names.extend(
["cross_attn_head_mask", "encoder_outputs"]
if "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
...@@ -512,6 +516,8 @@ class TFModelTesterMixin:
del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"]
if "cross_attn_head_mask" in inputs_dict:
del inputs_dict["cross_attn_head_mask"]
tf_main_layer_classes = set(
module_member
for model_class in self.all_model_classes
...@@ -639,7 +645,7 @@ class TFModelTesterMixin:
def check_decoder_attentions_output(outputs):
out_len = len(outputs)
self.assertEqual(out_len % 2, 0)
self.assertEqual(min(out_len % 2, out_len % 5), 0)  # differentiation due to newly added cross_attentions
decoder_attentions = outputs.decoder_attentions
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
...@@ -733,6 +739,8 @@ class TFModelTesterMixin:
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
if "cross_attn_head_mask" in arg_names:
inputs["cross_attn_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
...@@ -757,6 +765,8 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
if "cross_attn_head_mask" in arg_names:
check_attentions_validity(outputs.cross_attentions)
else:
check_attentions_validity(outputs.attentions)
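The validity check above boils down to: heads zeroed by the mask must contribute (near-)zero attention weight, while kept heads must not. A hedged sketch of that assertion logic (not the mixin's exact code; mask layout and tolerances are assumptions):

import numpy as np
import tensorflow as tf

def make_head_mask(num_layers, num_heads):
    mask = np.ones((num_layers, num_heads), dtype="float32")
    mask[0, 1:] = 0.0    # first layer: keep only head 0
    mask[-1, :-1] = 0.0  # last layer: keep only the last head
    return tf.constant(mask)

def check_attentions_validity(attentions):
    first, last = attentions[0].numpy(), attentions[-1].numpy()
    assert np.allclose(first[:, 1:, :, :], 0.0)     # masked heads carry no weight
    assert np.allclose(last[:, :-1, :, :], 0.0)
    assert not np.allclose(first[:, 0, :, :], 0.0)  # the kept head still attends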
......
...@@ -148,6 +148,7 @@ def prepare_marian_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -163,6 +164,8 @@ def prepare_marian_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
...@@ -170,6 +173,7 @@ def prepare_marian_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
......
...@@ -150,6 +150,7 @@ def prepare_mbart_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -165,13 +166,16 @@ def prepare_mbart_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
......
...@@ -146,6 +146,7 @@ def prepare_pegasus_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -161,6 +162,8 @@ def prepare_pegasus_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
...@@ -168,6 +171,7 @@ def prepare_pegasus_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
......