"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e53af030c032adabe83dfd8fd7c7576bd76dcf83"
Unverified Commit 8d79e5ca authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Fix head masking for TFT5 (#9877)



* Fix head_mask and decoder_head_mask in TFT5 models

* Enable test_headmasking for both the TFT5 tester
and the TFT5EncoderOnly tester
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent 4b919657
...@@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer): ...@@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer):
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
weights = weights * layer_head_mask tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.n_heads],
message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
)
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
...@@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
assert inputs["head_mask"] is None, "Head mask not supported"
inputs["head_mask"] = [None] * self.num_hidden_layers
assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
...@@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"]) hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])): for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if inputs["output_hidden_states"]: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer): ...@@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=inputs["head_mask"][i], layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][i], encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
...@@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r""" ...@@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
behaviors between training and evaluation). behaviors between training and evaluation).
""" """
__HEAD_MASK_WARNING_MSG = """ _HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
...@@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel):
""" """
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None: if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask decoder_head_mask = head_mask
inputs = input_processing( inputs = input_processing(
...@@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
""" """
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None: if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask decoder_head_mask = head_mask
inputs = input_processing( inputs = input_processing(
......
...@@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False test_onnx = False
def setUp(self): def setUp(self):
...@@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester: ...@@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False test_onnx = False
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment