Unverified Commit 1867d9a8 authored by Daniel Stancl, committed by GitHub

Add head_mask/decoder_head_mask for TF BART models (#9639)

* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask and decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model as these two have different
shape than other input args (Necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
parent cb73ab5a
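
Before the diff, a minimal usage sketch of the new arguments, assuming a tiny randomly initialized config (the config values below are illustrative, not from this PR). head_mask has shape (encoder_layers, encoder_attention_heads), decoder_head_mask has shape (decoder_layers, decoder_attention_heads), and a 0 entry masks the corresponding head.

import tensorflow as tf
from transformers import BartConfig, TFBartModel

# Illustrative config; the other TF BART-based models touched by this PR accept the same arguments.
config = BartConfig(encoder_layers=2, decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4)
model = TFBartModel(config)

input_ids = tf.constant([[0, 31414, 232, 2]])
# Keep all encoder heads, drop the first head of every decoder layer.
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
decoder_head_mask = tf.concat(
    [tf.zeros((config.decoder_layers, 1)), tf.ones((config.decoder_layers, config.decoder_attention_heads - 1))],
    axis=-1,
)
outputs = model(input_ids, decoder_input_ids=input_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask)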
@@ -164,6 +164,7 @@ class TFBartAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -230,6 +231,17 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
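
The masking above works purely by broadcasting: the per-layer mask of shape (num_heads,) is reshaped to (1, num_heads, 1, 1) and multiplied into the softmaxed attention weights of shape (bsz, num_heads, tgt_len, src_len), so a 0 entry zeroes out every attention row of that head. A standalone sketch with made-up sizes (nothing below is taken from the diff):

import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = tf.nn.softmax(tf.random.normal((bsz * num_heads, tgt_len, src_len)), axis=-1)
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # drop head 1

masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(attn_weights, (bsz, num_heads, tgt_len, src_len))
masked = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))
# Head 1 now contributes nothing to attn_output:
print(tf.reduce_sum(tf.reshape(masked, (bsz, num_heads, tgt_len, src_len))[:, 1]).numpy())  # 0.0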
@@ -266,16 +278,18 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -331,6 +345,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer): ...@@ -331,6 +345,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -342,6 +358,10 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -354,6 +374,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -370,6 +391,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -527,6 +549,18 @@ BART_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -593,6 +627,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -617,6 +652,12 @@ class TFBartEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -635,6 +676,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -670,8 +712,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -680,7 +729,11 @@ class TFBartEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -737,6 +790,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -774,6 +829,19 @@ class TFBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -802,6 +870,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -858,6 +928,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -875,6 +952,10 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -945,6 +1026,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -963,6 +1046,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -993,6 +1078,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1015,6 +1101,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1067,6 +1155,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1085,6 +1175,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1102,6 +1194,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1179,6 +1273,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1207,6 +1303,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1233,6 +1331,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1277,7 +1377,15 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
encoder_attentions=enc_attns,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1309,6 +1417,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
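
The commit message mentions a follow-up test for getting a gradient back for importance-score computation; a hedged sketch of what that could look like once head_mask flows through the graph (the objective below is a placeholder, not the eventual test):

import tensorflow as tf
from transformers import BartConfig, TFBartForConditionalGeneration

config = BartConfig(encoder_layers=2, decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4)
model = TFBartForConditionalGeneration(config)
input_ids = tf.constant([[0, 31414, 232, 2]])

head_mask = tf.Variable(tf.ones((config.encoder_layers, config.encoder_attention_heads)))
with tf.GradientTape() as tape:
    outputs = model(input_ids, decoder_input_ids=input_ids, head_mask=head_mask)
    objective = tf.reduce_sum(outputs.logits)  # stand-in for a real task loss

# One importance score per (layer, head).
importance = tape.gradient(objective, head_mask)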
@@ -167,6 +167,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -233,6 +234,17 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -270,17 +282,19 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -336,6 +350,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer): ...@@ -336,6 +350,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -347,6 +363,10 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -360,6 +380,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -376,6 +397,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -524,6 +546,18 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -590,6 +624,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -614,6 +649,12 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -632,6 +673,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -666,8 +708,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -676,7 +725,11 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -735,6 +788,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -772,6 +827,19 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -800,6 +868,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -855,6 +925,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -871,6 +949,10 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -943,6 +1025,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -961,6 +1045,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -983,6 +1069,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1005,6 +1092,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1070,6 +1159,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1088,6 +1179,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1105,6 +1198,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1196,6 +1291,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1224,6 +1321,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1249,6 +1348,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1295,7 +1396,15 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1327,6 +1436,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
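
The BlenderbotSmall changes below mirror the BART and Blenderbot ones, which is what lets the new test_headmasking check in tests/test_modeling_tf_common.py treat these models uniformly. A rough sketch of the kind of assertion such a test can make (the helper below is illustrative, not the actual test code):

import tensorflow as tf

def check_encoder_head_masking(model, input_ids, num_layers, num_heads):
    # Mask every encoder head except the last one in each layer.
    head_mask = tf.concat([tf.zeros((num_layers, num_heads - 1)), tf.ones((num_layers, 1))], axis=-1)
    outputs = model(input_ids, decoder_input_ids=input_ids, head_mask=head_mask, output_attentions=True)
    for layer_attention in outputs.encoder_attentions:
        # layer_attention: (batch, num_heads, seq_len, seq_len); masked heads must be all zeros.
        tf.debugging.assert_near(tf.reduce_sum(layer_attention[:, :-1]), 0.0)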
@@ -166,6 +166,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -232,6 +233,17 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -269,16 +281,18 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -335,6 +349,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer): ...@@ -335,6 +349,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -346,6 +362,10 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -358,6 +378,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -374,6 +395,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -529,6 +551,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" ...@@ -529,6 +551,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`). :obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
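The new `head_mask` / `decoder_head_mask` inputs documented above follow the `(num_layers, num_heads)` convention, with 1 meaning the head is kept and 0 meaning it is pruned. A minimal usage sketch, not part of this diff, assuming the public `TFBlenderbotSmallModel` / `BlenderbotSmallTokenizer` classes and the `facebook/blenderbot_small-90M` checkpoint:

import tensorflow as tf
from transformers import BlenderbotSmallTokenizer, TFBlenderbotSmallModel

tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot_small-90M")
model = TFBlenderbotSmallModel.from_pretrained("facebook/blenderbot_small-90M")
batch = tokenizer("Masking a single attention head.", return_tensors="tf")
config = model.config

# Start from all-ones masks (keep every head), then zero out one head per stack.
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
head_mask = tf.tensor_scatter_nd_update(head_mask, [[0, 0]], [0.0])  # drop head 0 of encoder layer 0
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
decoder_head_mask = tf.tensor_scatter_nd_update(decoder_head_mask, [[0, 1]], [0.0])  # drop head 1 of decoder layer 0

outputs = model(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=batch["input_ids"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
print(outputs.last_hidden_state.shape)  # (batch_size, target_sequence_length, d_model)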
...@@ -595,6 +629,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -595,6 +629,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
input_ids=None, input_ids=None,
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
...@@ -619,6 +654,12 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -619,6 +654,12 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -637,6 +678,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -637,6 +678,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
config=self.config, config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -672,8 +714,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -672,8 +714,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers # encoder layers
for encoder_layer in self.layers: for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -682,7 +731,11 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): ...@@ -682,7 +731,11 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
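For reference, the per-layer indexing added in this loop just hands each encoder layer a 1-D slice of shape `(encoder_attention_heads,)`. A standalone sketch of that slicing, with made-up sizes (2 layers, 4 heads):

import tensorflow as tf

head_mask = tf.constant([[1.0, 1.0, 0.0, 1.0],
                         [1.0, 0.0, 1.0, 1.0]])  # (encoder_layers, encoder_attention_heads)

for idx in range(head_mask.shape[0]):
    # Mirrors `inputs["head_mask"][idx] if inputs["head_mask"] is not None else None`
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    print(f"layer {idx}: {layer_head_mask.numpy()}")  # a (4,) vector handed to that layer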
...@@ -740,6 +793,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -740,6 +793,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -777,6 +832,19 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -777,6 +832,19 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
...@@ -805,6 +873,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -805,6 +873,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -859,6 +929,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -859,6 +929,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
all_self_attns = () all_self_attns = ()
present_key_values = () present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
...@@ -875,6 +952,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): ...@@ -875,6 +952,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -945,6 +1026,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -945,6 +1026,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -963,6 +1046,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -963,6 +1046,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -985,6 +1070,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -985,6 +1070,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1007,6 +1093,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): ...@@ -1007,6 +1093,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"], attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1059,6 +1147,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1059,6 +1147,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1077,6 +1167,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1077,6 +1167,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1094,6 +1186,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): ...@@ -1094,6 +1186,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1172,6 +1266,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1172,6 +1266,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1200,6 +1296,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1200,6 +1296,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1225,6 +1323,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1225,6 +1323,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1271,7 +1371,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1271,7 +1371,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
) )
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1: if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
...@@ -1303,6 +1411,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1303,6 +1411,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
"past_key_values": past_key_values, "past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
...@@ -196,6 +196,7 @@ class TFMarianAttention(tf.keras.layers.Layer): ...@@ -196,6 +196,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None, key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
...@@ -262,6 +263,17 @@ class TFMarianAttention(tf.keras.layers.Layer): ...@@ -262,6 +263,17 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1) attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
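The masking added just above works by reshaping the `(num_heads,)` mask to `(1, num_heads, 1, 1)` so it broadcasts over the batch and sequence axes, then flattening the weights back to `(bsz * num_heads, tgt_len, src_len)`. A self-contained illustration of that broadcasting step, with made-up sizes:

import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3  # illustrative sizes only
attn_weights = tf.nn.softmax(tf.random.normal((bsz * num_heads, tgt_len, src_len)), axis=-1)
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # zero out head 1

masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
    attn_weights, (bsz, num_heads, tgt_len, src_len)
)
masked = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))

# Row 1 of the flattened weights belongs to (batch 0, head 1) and is now all zeros;
# row 0 (batch 0, head 0) is untouched.
print(tf.reduce_sum(tf.abs(masked[1])).numpy(), tf.reduce_sum(tf.abs(masked[0])).numpy())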
...@@ -299,16 +311,18 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer): ...@@ -299,16 +311,18 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
""" """
Args: Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -365,6 +379,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -365,6 +379,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -376,6 +392,10 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -376,6 +392,10 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -388,6 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -388,6 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=self_attn_past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -404,6 +425,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer): ...@@ -404,6 +425,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -548,6 +570,18 @@ MARIAN_INPUTS_DOCSTRING = r""" ...@@ -548,6 +570,18 @@ MARIAN_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`). :obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
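Because the masks enter the graph as plain tensor multiplications, gradients with respect to them can serve as a head-importance signal, which is the use case the PR description hints at. A hedged sketch, assuming `TFMarianMTModel` and TF weights for the `Helsinki-NLP/opus-mt-en-de` checkpoint:

import tensorflow as tf
from transformers import MarianTokenizer, TFMarianMTModel

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = TFMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tokenizer("I like head masks.", return_tensors="tf")
config = model.config

head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

with tf.GradientTape() as tape:
    tape.watch([head_mask, decoder_head_mask])
    outputs = model(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["input_ids"],
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
    )
    # Any scalar objective works for probing; summing the logits keeps the sketch short.
    loss = tf.reduce_sum(outputs.logits)

enc_grads, dec_grads = tape.gradient(loss, [head_mask, decoder_head_mask])
print(enc_grads.shape, dec_grads.shape)  # one importance score per (layer, head)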
...@@ -612,6 +646,7 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -612,6 +646,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
input_ids=None, input_ids=None,
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
...@@ -636,6 +671,12 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -636,6 +671,12 @@ class TFMarianEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -654,6 +695,7 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -654,6 +695,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
config=self.config, config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -688,8 +730,15 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -688,8 +730,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers # encoder layers
for encoder_layer in self.layers: for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -698,7 +747,11 @@ class TFMarianEncoder(tf.keras.layers.Layer): ...@@ -698,7 +747,11 @@ class TFMarianEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
...@@ -753,6 +806,8 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -753,6 +806,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -790,6 +845,19 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -790,6 +845,19 @@ class TFMarianDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
...@@ -818,6 +886,8 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -818,6 +886,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -872,6 +942,14 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -872,6 +942,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
all_hidden_states = () all_hidden_states = ()
all_self_attns = () all_self_attns = ()
present_key_values = () present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
...@@ -888,6 +966,10 @@ class TFMarianDecoder(tf.keras.layers.Layer): ...@@ -888,6 +966,10 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -958,6 +1040,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -958,6 +1040,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -976,6 +1060,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -976,6 +1060,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1001,6 +1087,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1001,6 +1087,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1023,6 +1110,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer): ...@@ -1023,6 +1110,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"], attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1075,6 +1164,8 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1075,6 +1164,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1092,6 +1183,8 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1092,6 +1183,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
...@@ -1110,6 +1203,8 @@ class TFMarianModel(TFMarianPreTrainedModel): ...@@ -1110,6 +1203,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1188,6 +1283,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1188,6 +1283,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1216,6 +1313,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1216,6 +1313,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1242,6 +1341,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1242,6 +1341,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1288,7 +1389,15 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1288,7 +1389,15 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
) )
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1: if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
...@@ -1320,6 +1429,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel): ...@@ -1320,6 +1429,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
"past_key_values": past_key_values, "past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
...@@ -170,6 +170,7 @@ class TFMBartAttention(tf.keras.layers.Layer): ...@@ -170,6 +170,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None, key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
...@@ -236,6 +237,17 @@ class TFMBartAttention(tf.keras.layers.Layer): ...@@ -236,6 +237,17 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1) attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
...@@ -272,17 +284,19 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer): ...@@ -272,17 +284,19 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
""" """
Args: Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -337,6 +351,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -337,6 +351,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -348,6 +364,10 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -348,6 +364,10 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -361,6 +381,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -361,6 +381,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=self_attn_past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -377,6 +398,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer): ...@@ -377,6 +398,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -505,6 +527,18 @@ MBART_INPUTS_DOCSTRING = r""" ...@@ -505,6 +527,18 @@ MBART_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper. the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
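As in the other two models, the MBart encoder and decoder (hunks further below) guard the new argument with `tf.debugging.assert_equal`, so a mask that does not cover every layer fails loudly. A standalone sketch of that check, with illustrative layer/head counts:

import tensorflow as tf

num_layers, num_heads = 12, 16       # illustrative mbart-large-like sizes
head_mask = tf.ones((6, num_heads))  # deliberately covers only 6 layers

try:
    tf.debugging.assert_equal(
        head_mask.shape[0],
        num_layers,
        message=f"The head_mask should be specified for {num_layers} layers, "
        f"but it is for {head_mask.shape[0]}.",
    )
except tf.errors.InvalidArgumentError as err:
    print(err.message)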
...@@ -601,6 +635,7 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -601,6 +635,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
input_ids=None, input_ids=None,
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
...@@ -625,6 +660,12 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -625,6 +660,12 @@ class TFMBartEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -643,6 +684,7 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -643,6 +684,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
config=self.config, config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -678,8 +720,15 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -678,8 +720,15 @@ class TFMBartEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers # encoder layers
for encoder_layer in self.layers: for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -688,7 +737,11 @@ class TFMBartEncoder(tf.keras.layers.Layer): ...@@ -688,7 +737,11 @@ class TFMBartEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
...@@ -748,6 +801,8 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -748,6 +801,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -785,6 +840,19 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -785,6 +840,19 @@ class TFMBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
...@@ -813,6 +881,8 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -813,6 +881,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -868,6 +938,14 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -868,6 +938,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
all_hidden_states = () all_hidden_states = ()
all_self_attns = () all_self_attns = ()
present_key_values = () present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
...@@ -884,6 +962,10 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -884,6 +962,10 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -956,6 +1038,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -956,6 +1038,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -974,6 +1058,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -974,6 +1058,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1002,6 +1088,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1002,6 +1088,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1024,6 +1111,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer): ...@@ -1024,6 +1111,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"], attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1076,6 +1165,8 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1076,6 +1165,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1094,6 +1185,8 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1094,6 +1185,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1111,6 +1204,8 @@ class TFMBartModel(TFMBartPreTrainedModel): ...@@ -1111,6 +1204,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
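A hedged end-to-end usage sketch for the two new model-level arguments, assuming MBartConfig and TFMBartModel are importable as shown in this diff; the tiny config values and token ids are made up for illustration and the model is randomly initialised:
import tensorflow as tf
from transformers import MBartConfig, TFMBartModel

config = MBartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = TFMBartModel(config)  # randomly initialised, illustration only

input_ids = tf.constant([[5, 6, 7, 2]])
decoder_input_ids = tf.constant([[2, 5, 6]])

# head_mask masks encoder heads, decoder_head_mask masks decoder heads.
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)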
...@@ -1189,6 +1284,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1189,6 +1284,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1217,6 +1314,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1217,6 +1314,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1241,6 +1340,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1241,6 +1340,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1287,7 +1388,15 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1287,7 +1388,15 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
) )
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1: if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
...@@ -1319,6 +1428,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel): ...@@ -1319,6 +1428,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
"past_key_values": past_key_values, "past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
...@@ -197,6 +197,7 @@ class TFPegasusAttention(tf.keras.layers.Layer): ...@@ -197,6 +197,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None, key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
...@@ -263,6 +264,17 @@ class TFPegasusAttention(tf.keras.layers.Layer): ...@@ -263,6 +264,17 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1) attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training) attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states) attn_output = tf.matmul(attn_probs, value_states)
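A minimal sketch (illustrative sizes, independent of the model code) of the broadcast multiply added above: reshaping the per-layer (num_heads,) mask to (1, num_heads, 1, 1) zeroes out the attention weights of the masked heads while leaving the others untouched.
import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 1, 4, 2, 2  # illustrative
attn_weights = tf.nn.softmax(tf.random.normal((bsz * num_heads, tgt_len, src_len)), axis=-1)
layer_head_mask = tf.constant([0.0, 1.0, 1.0, 1.0])  # drop head 0, keep heads 1-3

masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
    attn_weights, (bsz, num_heads, tgt_len, src_len)
)
masked = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))
# The first tgt_len x src_len block (head 0) is now all zeros; the other heads are unchanged.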
...@@ -300,17 +312,19 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer): ...@@ -300,17 +312,19 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
""" """
Args: Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn( hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
) )
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(hidden_states), shape_list(hidden_states),
...@@ -366,6 +380,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -366,6 +380,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False, training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
...@@ -377,6 +393,10 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -377,6 +393,10 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
""" """
residual = hidden_states residual = hidden_states
...@@ -390,6 +410,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -390,6 +410,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=self_attn_past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -406,6 +427,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer): ...@@ -406,6 +427,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=cross_attn_past_key_value,
) )
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=training)
...@@ -553,6 +575,18 @@ PEGASUS_INPUTS_DOCSTRING = r""" ...@@ -553,6 +575,18 @@ PEGASUS_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper. the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`): encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
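As a hedged illustration of the mask convention documented above (ones keep heads, zeros nullify them); the layer and head counts below are invented, not real Pegasus config values:
import tensorflow as tf

encoder_layers, encoder_attention_heads = 2, 4  # illustrative only

keep_all_heads = tf.ones((encoder_layers, encoder_attention_heads))
drop_first_head_in_every_layer = tf.concat(
    [tf.zeros((encoder_layers, 1)), tf.ones((encoder_layers, encoder_attention_heads - 1))],
    axis=-1,
)
# Either tensor can be passed as head_mask; the decoder-sized analogue goes in decoder_head_mask.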
...@@ -618,6 +652,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -618,6 +652,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
input_ids=None, input_ids=None,
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
head_mask=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
...@@ -642,6 +677,12 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -642,6 +677,12 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
...@@ -660,6 +701,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -660,6 +701,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
config=self.config, config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -694,8 +736,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -694,8 +736,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers # encoder layers
for encoder_layer in self.layers: for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
...@@ -704,7 +753,11 @@ class TFPegasusEncoder(tf.keras.layers.Layer): ...@@ -704,7 +753,11 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
...@@ -762,6 +815,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -762,6 +815,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -799,6 +854,19 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -799,6 +854,19 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__ `What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding. decoding.
...@@ -827,6 +895,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -827,6 +895,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
...@@ -881,6 +951,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -881,6 +951,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
all_hidden_states = () all_hidden_states = ()
all_self_attns = () all_self_attns = ()
present_key_values = () present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
...@@ -897,6 +975,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -897,6 +975,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"], encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -969,6 +1051,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -969,6 +1051,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -987,6 +1071,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -987,6 +1071,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1012,6 +1098,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1012,6 +1098,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder( inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
...@@ -1034,6 +1121,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): ...@@ -1034,6 +1121,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"], attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
...@@ -1086,6 +1175,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1086,6 +1175,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1104,6 +1195,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1104,6 +1195,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1121,6 +1214,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1121,6 +1214,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
...@@ -1199,6 +1294,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1199,6 +1294,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
inputs_embeds=None, inputs_embeds=None,
...@@ -1227,6 +1324,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1227,6 +1324,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
...@@ -1253,6 +1352,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1253,6 +1352,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"], encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"], decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
...@@ -1299,7 +1400,15 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1299,7 +1400,15 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
) )
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1: if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
...@@ -1331,6 +1440,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel): ...@@ -1331,6 +1440,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
"past_key_values": past_key_values, "past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
} }
......
...@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFAlbertModelTester(self) self.model_tester = TFAlbertModelTester(self)
......
...@@ -108,10 +108,11 @@ class TFBartModelTester: ...@@ -108,10 +108,11 @@ class TFBartModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict( ...@@ -144,6 +145,8 @@ def prepare_bart_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -155,11 +158,17 @@ def prepare_bart_inputs_dict( ...@@ -155,11 +158,17 @@ def prepare_bart_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": head_mask,
} }
...@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFBartModelTester(self) self.model_tester = TFBartModelTester(self)
......
...@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
# special case for ForPreTraining model # special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
...@@ -107,10 +107,11 @@ class TFBlenderbotModelTester: ...@@ -107,10 +107,11 @@ class TFBlenderbotModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict( ...@@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict( ...@@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
} }
...@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFBlenderbotModelTester(self) self.model_tester = TFBlenderbotModelTester(self)
......
...@@ -107,10 +107,11 @@ class TFBlenderbotSmallModelTester: ...@@ -107,10 +107,11 @@ class TFBlenderbotSmallModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict( ...@@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -154,11 +157,17 @@ def prepare_blenderbot_small_inputs_dict( ...@@ -154,11 +157,17 @@ def prepare_blenderbot_small_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
} }
...@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self) self.model_tester = TFBlenderbotSmallModelTester(self)
......
...@@ -440,6 +440,11 @@ class TFModelTesterMixin: ...@@ -440,6 +440,11 @@ class TFModelTesterMixin:
def test_train_pipeline_custom_model(self): def test_train_pipeline_custom_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# head_mask and decoder_head_mask have different shapes than the other input args
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
if "decoder_head_mask" in inputs_dict:
del inputs_dict["decoder_head_mask"]
tf_main_layer_classes = set( tf_main_layer_classes = set(
module_member module_member
for model_class in self.all_model_classes for model_class in self.all_model_classes
...@@ -620,6 +625,75 @@ class TFModelTesterMixin: ...@@ -620,6 +625,75 @@ class TFModelTesterMixin:
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs) check_encoder_attentions_output(outputs)
def test_headmasking(self):
if not self.test_head_masking:
return
random.Random().seed(42)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
random.Random().seed()
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Prepare head_mask
def prepare_layer_head_mask(i, attention_heads, num_hidden_layers):
if i == 0:
return tf.concat(
(tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0
)
elif i == num_hidden_layers - 1:
return tf.concat(
(tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0
)
else:
return tf.ones(attention_heads, dtype=tf.float32)
head_mask = tf.stack(
[
prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers)
for i in range(config.num_hidden_layers)
],
0,
)
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask
if model.config.is_encoder_decoder:
signature = inspect.signature(model.call)
arg_names = [*signature.parameters.keys()]
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
inputs["decoder_head_mask"] = head_mask
outputs = model(**inputs, return_dict=True)
def check_attentions_validity(attentions):
# Remove Nan
for t in attentions:
self.assertLess(
(tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy()
) # Check we don't have more than 25% nans (arbitrary)
attentions = [
tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions
] # remove them (the test is less complete)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0)
if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module
self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0)
self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0)
self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0)
if model.config.is_encoder_decoder:
check_attentions_validity(outputs.encoder_attentions)
check_attentions_validity(outputs.decoder_attentions)
else:
check_attentions_validity(outputs.attentions)
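A toy, self-contained check in the spirit of check_attentions_validity above (shapes and values are invented): attentions for one layer have shape (batch, num_heads, seq_len, seq_len), so a fully masked head should contribute a block that sums to zero while unmasked heads do not.
import tensorflow as tf

batch, num_heads, seq_len = 1, 4, 3  # illustrative
fake_layer_attentions = tf.concat(
    [
        tf.zeros((batch, 1, seq_len, seq_len)),                   # head 0 masked out
        tf.fill((batch, num_heads - 1, seq_len, seq_len), 0.25),  # remaining heads untouched
    ],
    axis=1,
)
assert float(tf.math.reduce_sum(fake_layer_attentions[..., 0, :, :])) == 0.0
assert float(tf.math.reduce_sum(fake_layer_attentions[..., -1, :, :])) != 0.0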
def test_hidden_states_output(self): def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else () all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else () all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFCTRLModelTester(self) self.model_tester = TFCTRLModelTester(self)
......
...@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else None else None
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFDistilBertModelTester(self) self.model_tester = TFDistilBertModelTester(self)
......
...@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFElectraModelTester(self) self.model_tester = TFElectraModelTester(self)
......
...@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFFlaubertWithLMHeadModel,) if is_tf_available() else () (TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable ) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFFlaubertModelTester(self) self.model_tester = TFFlaubertModelTester(self)
......
...@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFFunnelModelTester(self) self.model_tester = TFFunnelModelTester(self)
...@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else () (TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True) self.model_tester = TFFunnelModelTester(self, base=True)
......
...@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
else () else ()
) )
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else () all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFGPT2ModelTester(self) self.model_tester = TFGPT2ModelTester(self)
......
...@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFLEDModelTester(self) self.model_tester = TFLEDModelTester(self)
......
...@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFLongformerModelTester(self) self.model_tester = TFLongformerModelTester(self)
......