Unverified Commit 1867d9a8 authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add head_mask/decoder_head_mask for TF BART models (#9639)

* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask and decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model as these two have different
shape than other input args (Necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
parent cb73ab5a
...@@ -361,6 +361,7 @@ class TFLxmertModelTester(object): ...@@ -361,6 +361,7 @@ class TFLxmertModelTester(object):
class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase): class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else () all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFLxmertModelTester(self) self.model_tester = TFLxmertModelTester(self)
......
...@@ -109,10 +109,11 @@ class TFMarianModelTester: ...@@ -109,10 +109,11 @@ class TFMarianModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict( ...@@ -145,6 +146,8 @@ def prepare_marian_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -156,11 +159,17 @@ def prepare_marian_inputs_dict( ...@@ -156,11 +159,17 @@ def prepare_marian_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
} }
...@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else () all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFMarianModelTester(self) self.model_tester = TFMarianModelTester(self)
......
...@@ -106,10 +106,11 @@ class TFMBartModelTester: ...@@ -106,10 +106,11 @@ class TFMBartModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -147,6 +148,8 @@ def prepare_mbart_inputs_dict( ...@@ -147,6 +148,8 @@ def prepare_mbart_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -158,11 +161,17 @@ def prepare_mbart_inputs_dict( ...@@ -158,11 +161,17 @@ def prepare_mbart_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
} }
...@@ -172,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -172,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFMBartModelTester(self) self.model_tester = TFMBartModelTester(self)
......
...@@ -55,6 +55,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -55,6 +55,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
class TFMobileBertModelTester(object): class TFMobileBertModelTester(object):
def __init__( def __init__(
......
...@@ -198,6 +198,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -198,6 +198,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFMPNetModelTester(self) self.model_tester = TFMPNetModelTester(self)
......
...@@ -202,6 +202,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -202,6 +202,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else () (TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFOpenAIGPTModelTester(self) self.model_tester = TFOpenAIGPTModelTester(self)
......
...@@ -107,10 +107,11 @@ class TFPegasusModelTester: ...@@ -107,10 +107,11 @@ class TFPegasusModelTester:
input_ids = input_ids[:1, :] input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :] attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1 self.batch_size = 1
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1] past_key_values = past_key_values[1]
...@@ -143,6 +144,8 @@ def prepare_pegasus_inputs_dict( ...@@ -143,6 +144,8 @@ def prepare_pegasus_inputs_dict(
decoder_input_ids, decoder_input_ids,
attention_mask=None, attention_mask=None,
decoder_attention_mask=None, decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
): ):
if attention_mask is None: if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
...@@ -154,11 +157,17 @@ def prepare_pegasus_inputs_dict( ...@@ -154,11 +157,17 @@ def prepare_pegasus_inputs_dict(
], ],
axis=-1, axis=-1,
) )
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
} }
...@@ -168,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -168,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_head_masking = True
def setUp(self): def setUp(self):
self.model_tester = TFPegasusModelTester(self) self.model_tester = TFPegasusModelTester(self)
......
...@@ -185,6 +185,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -185,6 +185,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available() if is_tf_available()
else () else ()
) )
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFRobertaModelTester(self) self.model_tester = TFRobertaModelTester(self)
......
...@@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFT5ModelTester(self) self.model_tester = TFT5ModelTester(self)
...@@ -417,6 +418,7 @@ class TFT5EncoderOnlyModelTester: ...@@ -417,6 +418,7 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFT5EncoderOnlyModelTester(self) self.model_tester = TFT5EncoderOnlyModelTester(self)
......
...@@ -163,6 +163,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -163,6 +163,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = () if is_tf_available() else () all_generative_model_classes = () if is_tf_available() else ()
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFTransfoXLModelTester(self) self.model_tester = TFTransfoXLModelTester(self)
......
...@@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFXLMWithLMHeadModel,) if is_tf_available() else () (TFXLMWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable ) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFXLMModelTester(self) self.model_tester = TFXLMModelTester(self)
......
...@@ -347,6 +347,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -347,6 +347,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = ( all_generative_model_classes = (
(TFXLNetLMHeadModel,) if is_tf_available() else () (TFXLNetLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable ) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
def setUp(self): def setUp(self):
self.model_tester = TFXLNetModelTester(self) self.model_tester = TFXLNetModelTester(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment