"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "332cb7104ab9db20b8c6abd04a517a4838a30a7d"
Unverified Commit e7381c45 authored by Daniel Stancl, committed by GitHub

Add head_mask and decoder_head_mask to TF LED (#9988)

* Add head masking to TF LED

* Add head_mask to Longformer + one doc piece to LED

* Fix integration tests
parent 77c0ce8c
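For context before the diff, here is a minimal sketch of what the new arguments look like from the user side. It assumes the `allenai/led-base-16384` checkpoint and illustrative input text; both masks default to `None`, which keeps every head active.

```python
import tensorflow as tf
from transformers import LEDTokenizer, TFLEDForConditionalGeneration

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("Summarize: the quick brown fox jumps over the lazy dog.", return_tensors="tf")

# One row per layer, one column per head: 1 keeps a head, 0 silences it.
head_mask = tf.ones((model.config.encoder_layers, model.config.encoder_attention_heads))
decoder_head_mask = tf.ones((model.config.decoder_layers, model.config.decoder_attention_heads))

outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
```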
@@ -719,6 +719,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
@@ -794,6 +795,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_probs,
)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self.dropout(attn_probs, training=training)
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
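The broadcast above is worth spelling out: in this sliding-window path the probabilities carry the head axis third (the layer-level tests later in this diff show local attentions of shape `(batch, seq_len, heads, window)`), so reshaping the `(num_heads,)` mask to `(1, 1, num_heads, 1)` zeroes exactly one head at every position. A toy sketch with made-up sizes:

```python
import tensorflow as tf

batch_size, seq_len, num_heads, window = 2, 8, 4, 5  # toy sizes
attn_probs = tf.random.uniform((batch_size, seq_len, num_heads, window))

layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # silence head 1

# Same broadcast as in the hunk above: (num_heads,) -> (1, 1, num_heads, 1).
masked = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

print(float(tf.reduce_max(masked[:, :, 1, :])))  # 0.0: head 1 contributes nothing
```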
@@ -829,6 +838,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
@@ -1271,6 +1281,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_output,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
@@ -1336,6 +1347,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# compute global attn probs
global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
        # apply layer head masking
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
)
global_attn_probs_float = tf.reshape(
global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
)
# dropout
global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
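In the global-attention path the probabilities are kept flat over `(batch_size * num_heads)`, so the hunk above un-flattens them to expose the head axis, broadcasts the mask as `(1, num_heads, 1, 1)`, and flattens back. A toy sketch of the same round trip:

```python
import tensorflow as tf

batch_size, num_heads, max_global, seq_len = 2, 4, 3, 8  # toy sizes
probs = tf.random.uniform((batch_size * num_heads, max_global, seq_len))

layer_head_mask = tf.constant([1.0, 1.0, 0.0, 1.0])  # silence head 2

# Un-flatten to expose the head axis, broadcast the mask, flatten again.
probs_4d = tf.reshape(probs, (batch_size, num_heads, max_global, seq_len))
probs_4d = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * probs_4d
probs = tf.reshape(probs_4d, (batch_size * num_heads, max_global, seq_len))

print(float(tf.reduce_max(probs_4d[:, 2])))  # 0.0: head 2 is zeroed in every batch
```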
@@ -1398,13 +1423,14 @@ class TFLongformerAttention(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
) = inputs
self_outputs = self.self_attention(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
@@ -1425,13 +1451,14 @@ class TFLongformerLayer(tf.keras.layers.Layer):
(
hidden_states,
attention_mask,
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
) = inputs
attention_outputs = self.attention(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
training=training,
)
attention_output = attention_outputs[0]
@@ -1469,7 +1496,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
all_hidden_states = () if output_hidden_states else None
all_attentions = all_global_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
for idx, layer_module in enumerate(self.layer):
if output_hidden_states:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
@@ -1478,6 +1505,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
[
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
is_global_attn,
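The `head_mask[idx]` slicing above means each layer sees only its own `(num_heads,)` row of the full `(num_layers, num_heads)` mask. A toy illustration of the pattern:

```python
import tensorflow as tf

num_layers, num_heads = 3, 4  # toy sizes
head_mask = tf.ones((num_layers, num_heads))

for idx in range(num_layers):
    # Each layer receives its own row, or None when no mask was given.
    layer_head_mask = head_mask[idx] if head_mask is not None else None
    print(layer_head_mask.shape)  # (4,): one row per layer
```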
@@ -1558,6 +1586,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self,
input_ids=None,
attention_mask=None,
head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1573,6 +1602,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1649,6 +1679,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
padding_len=padding_len,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
@@ -1842,6 +1873,12 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
global_attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to decide the attention given on each token, local attention or global attention. Tokens with global
attention attends to all other tokens, and all other tokens attend to them. This is important for
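The `head_mask` entry documented above translates directly into code. A minimal sketch, with illustrative sizes (the real values come from the model config):

```python
import numpy as np
import tensorflow as tf

encoder_layers, encoder_attention_heads = 12, 12  # illustrative sizes
mask = np.ones((encoder_layers, encoder_attention_heads), dtype=np.float32)
mask[3, 5] = 0.0  # 0 masks head 5 of layer 3; 1 everywhere else keeps heads active
head_mask = tf.constant(mask)
```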
@@ -1918,6 +1955,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
self,
input_ids=None,
attention_mask=None,
head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1933,6 +1971,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1946,6 +1985,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2004,6 +2044,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
self,
input_ids=None,
attention_mask=None,
head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2026,6 +2067,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2040,6 +2082,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2109,6 +2152,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
self,
input_ids=None,
attention_mask=None,
head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2136,6 +2180,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2170,6 +2215,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2274,6 +2320,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
self,
input_ids=None,
attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2290,6 +2337,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2321,6 +2369,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -2397,6 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
self,
input_ids=None,
attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2419,6 +2469,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2464,6 +2515,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
head_mask=head_mask,
global_attention_mask=flat_global_attention_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
@@ -2547,6 +2599,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
self,
input_ids=None,
attention_mask=None,
head_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
@@ -2568,6 +2621,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -2582,6 +2636,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
@@ -162,6 +162,8 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -173,11 +175,17 @@ def prepare_led_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
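The all-ones defaults added here are a no-op by construction: since the model multiplies attention probabilities by the mask, a mask of ones leaves them unchanged, so existing tests that never set a head mask are unaffected. A quick sketch with toy sizes:

```python
import tensorflow as tf

encoder_layers, encoder_attention_heads = 2, 4  # toy sizes
head_mask = tf.ones((encoder_layers, encoder_attention_heads))

attn_probs = tf.random.uniform((3, 7, encoder_attention_heads, 5))
masked = tf.reshape(head_mask[0], (1, 1, -1, 1)) * attn_probs  # layer 0's row

tf.debugging.assert_near(masked, attn_probs)  # identical: no head is silenced
```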
@@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)
@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
)[0]
expected_slice = tf.convert_to_tensor(
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)[0]
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states, local_attentions, global_attentions = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))