"docs/vscode:/vscode.git/clone" did not exist on "1f5ea9e04a27171a5034a61999bc6359d19fe4ef"
Unverified Commit 8d79e5ca authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Fix head masking for TFT5 (#9877)



* Fix head_mask and decoder_head_mask in TFT5 models

* Enable test_headmasking both fot TFT5 tester
and TFT5EncoderOnly tester
Co-authored-by: default avatarpatrickvonplaten <patrick.v.platen@gmail.com>
parent 4b919657
......@@ -344,7 +344,12 @@ class TFT5Attention(tf.keras.layers.Layer):
# Mask heads if we want to
if layer_head_mask is not None:
weights = weights * layer_head_mask
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.n_heads],
message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
)
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
......@@ -711,10 +716,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
else:
encoder_extended_attention_mask = None
assert inputs["head_mask"] is None, "Head mask not supported"
inputs["head_mask"] = [None] * self.num_hidden_layers
assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
......@@ -723,7 +724,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
......@@ -733,8 +734,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=inputs["head_mask"][i],
encoder_layer_head_mask=inputs["encoder_head_mask"][i],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
......@@ -1057,7 +1060,7 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
behaviors between training and evaluation).
"""
__HEAD_MASK_WARNING_MSG = """
_HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
......@@ -1133,7 +1136,7 @@ class TFT5Model(TFT5PreTrainedModel):
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
......@@ -1327,7 +1330,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
inputs = input_processing(
......
......@@ -248,7 +248,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
......@@ -427,7 +426,6 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment