Unverified Commit 62f5ae68 authored by Patrick von Platen, committed by GitHub

[Seq2Seq] Fix a couple of bugs and clean examples (#7474)



* clean T5

* fix t5 tests

* fix index typo

* fix tf common test

* fix examples

* change positional ordering for Bart and FSMT

* add signature test

* clean docs and add tests

* add docs to encoder decoder

* clean docs

* correct two doc strings

* remove sig test for TF Electra & Funnel

* fix tf t5 slow tests

* rename input_ids to inputs in TF

* Update src/transformers/modeling_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* implement Lysandre's review suggestions

* make style

* fix encoder decoder typo

* fix tf slow tests

* fix slow tests

* renaming

* remove unused input
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a42f62d3
@@ -71,7 +71,7 @@ class ModelTester:
         # hack needed for modeling_common tests - despite not really having this attribute in this model
         self.vocab_size = self.src_vocab_size
 
-    def prepare_config_and_inputs_for_common(self):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.src_vocab_size).clamp(
             3,
         )
@@ -99,6 +99,13 @@ class ModelTester:
         inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
         return config, inputs_dict
 
+    def prepare_config_and_inputs_for_common(self):
+        config, inputs_dict = self.prepare_config_and_inputs()
+        inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
+        inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"]
+        inputs_dict["use_cache"] = False
+        return config, inputs_dict
+
 
 def prepare_fsmt_inputs_dict(
     config,
@@ -142,7 +149,7 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
     # XXX: override test_model_common_attributes / different Embedding type
     def test_model_common_attributes(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
 
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -152,7 +159,7 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(x is None or isinstance(x, torch.nn.modules.sparse.Embedding))
 
     def test_initialization_more(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         model = FSMTModel(config)
         model.to(torch_device)
         model.eval()
@@ -170,7 +177,7 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
         # self.assertAlmostEqual(torch.std(model.encoder.embed_positions.weights).item(), config.init_std, 2)
 
     def test_advanced_inputs(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         config.use_cache = False
         inputs_dict["input_ids"][:, -2:] = config.pad_token_id
         decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
@@ -200,7 +207,7 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
         _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
 
     def test_save_load_strict(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -210,7 +217,7 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertEqual(info["missing_keys"], [])
 
     def test_save_load_no_save_keys(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         for model_class in self.all_model_classes:
             model = model_class(config)
...
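For context, the new prepare_config_and_inputs_for_common supplies the decoder-side tensors that the shared ModelTesterMixin tests expect from an encoder-decoder model. A minimal, self-contained sketch of the resulting inputs dict; the small FSMTConfig values below are hypothetical, chosen only to keep the model tiny:

import torch
from transformers import FSMTConfig, FSMTModel

# hypothetical small config, not taken from the test suite
config = FSMTConfig(
    langs=["en", "de"], src_vocab_size=99, tgt_vocab_size=99, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
input_ids = torch.randint(3, 99, (2, 7))
attention_mask = input_ids.ne(config.pad_token_id)
inputs_dict = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    # reuse the encoder inputs on the decoder side, as the tester does
    "decoder_input_ids": input_ids,
    "decoder_attention_mask": attention_mask,
    "use_cache": False,  # the common tests run cache-free
}
model = FSMTModel(config).eval()
with torch.no_grad():
    outputs = model(**inputs_dict)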
@@ -261,6 +261,38 @@ class GPT2ModelTester:
         # test that outputs are equal for slice
         self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
 
+    def create_and_check_gpt2_model_past_large_inputs(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
+    ):
+        model = GPT2Model(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # first forward pass
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+
+        output, past = outputs.to_tuple()
+
+        # create hypothetical next token and extent to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+        next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
+
+        # append to next input_ids and token_type_ids
+        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
+
+        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
+        output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
+        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+        # test that outputs are equal for slice
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
     def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = GPT2LMHeadModel(config)
         model.to(torch_device)
@@ -357,6 +389,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_model_attention_mask_past(*config_and_inputs)
 
+    def test_gpt2_model_past_large_inputs(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gpt2_model_past_large_inputs(*config_and_inputs)
+
     def test_gpt2_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
...
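The new create_and_check_gpt2_model_past_large_inputs strengthens the cache check from a single next token to a three-token continuation. A self-contained sketch of the property being asserted; the tiny config values are hypothetical, and the past/use_cache kwargs follow the API at this commit (past was later renamed past_key_values):

import torch
from transformers import GPT2Config, GPT2Model

# hypothetical small config; only the cache equivalence matters here
config = GPT2Config(vocab_size=99, n_positions=64, n_embd=32, n_layer=2, n_head=2, return_dict=True)
model = GPT2Model(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 7))
with torch.no_grad():
    # first pass returns the key/value cache alongside the hidden states
    output, past = model(input_ids, use_cache=True).to_tuple()
    next_tokens = torch.randint(0, config.vocab_size, (2, 3))
    # feeding only the new tokens plus the cache ...
    from_past = model(next_tokens, past=past)["last_hidden_state"]
    # ... must match a full forward pass over the concatenated sequence
    full = model(torch.cat([input_ids, next_tokens], dim=-1))["last_hidden_state"]
assert torch.allclose(from_past, full[:, -3:], atol=1e-3)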
@@ -235,7 +235,7 @@ class T5ModelTester:
         self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
         self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
 
-        output, past_key_value_states = outputs.to_tuple()
+        output, past_key_values = outputs.to_tuple()
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -244,7 +244,7 @@ class T5ModelTester:
         next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
 
         output_from_no_past = model(next_input_ids)["last_hidden_state"]
-        output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
+        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
 
         # select random slice
         random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
@@ -274,7 +274,7 @@ class T5ModelTester:
         attn_mask[:, half_seq_length:] = 0
 
         # first forward pass
-        output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
+        output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -293,7 +293,7 @@ class T5ModelTester:
 
         # get two different outputs
         output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
-        output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
+        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
             "last_hidden_state"
         ]
@@ -305,7 +305,41 @@ class T5ModelTester:
         # test that outputs are equal for slice
         self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
 
-    def create_and_check_generate_with_past_key_value_states(
+    def create_and_check_decoder_model_past_large_inputs(
+        self,
+        config,
+        input_ids,
+        decoder_input_ids,
+        attention_mask,
+        decoder_attention_mask,
+        lm_labels,
+    ):
+        model = T5Model(config=config).get_decoder().to(torch_device).eval()
+        # first forward pass
+        outputs = model(input_ids, use_cache=True)
+
+        output, past_key_values = outputs.to_tuple()
+
+        # create hypothetical multiple next token and extent to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+
+        # append to next input_ids and
+        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+        output_from_no_past = model(next_input_ids)["last_hidden_state"]
+        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+        # test that outputs are equal for slice
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+    def create_and_check_generate_with_past_key_values(
         self,
         config,
         input_ids,
@@ -439,7 +473,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
-    test_torchscript = False
+    test_torchscript = True
     test_resize_embeddings = False
     is_encoder_decoder = True
@@ -470,9 +504,13 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
 
-    def test_generate_with_past_key_value_states(self):
+    def test_decoder_model_past_with_large_inputs(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+    def test_generate_with_past_key_values(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
+        self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs)
 
     def test_encoder_decoder_shared_weights(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -495,10 +533,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.onnx.export(
                 model,
-                config_and_inputs[1],
+                (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
                 f"{tmpdirname}/t5_test.onnx",
                 export_params=True,
                 opset_version=9,
+                input_names=["input_ids", "decoder_input_ids"],
             )
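After the positional reordering, T5's forward begins with input_ids, attention_mask, decoder_input_ids, so the ONNX example inputs are now a tuple in that order: config_and_inputs[1], [3], and [2] are input_ids, attention_mask, and decoder_input_ids respectively (see the tester's argument order above). A hedged, standalone sketch of the same export, assuming the library at the state of this commit and a hypothetical small config:

import tempfile

import torch
from transformers import T5Config, T5Model

# hypothetical small config; caching and dict outputs are disabled so the
# tracer sees plain tensors, mirroring the updated slow test
config = T5Config(
    vocab_size=99, d_model=32, d_ff=64, num_layers=2, num_heads=2,
    use_cache=False, return_dict=False,
)
model = T5Model(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.randint(0, config.vocab_size, (1, 7))

with tempfile.TemporaryDirectory() as tmpdirname:
    torch.onnx.export(
        model,
        # positional order after this PR: input_ids, attention_mask, decoder_input_ids
        (input_ids, attention_mask, decoder_input_ids),
        f"{tmpdirname}/t5_test.onnx",
        export_params=True,
        opset_version=9,
        input_names=["input_ids", "decoder_input_ids"],
    )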
@@ -527,7 +566,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
     ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
 
     expected_summaries = [
-        'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," the magazine says .',
+        'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
         "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
         "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
         'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
@@ -604,13 +643,6 @@ class T5ModelIntegrationTests(unittest.TestCase):
             "sous forme "
             "de points bleus."
         )
-        # expected_translation = (
-        #     "Cette section d'images provenant de l'enregistrement infrarouge effectué par le "
-        #     "télescope Spitzer montre un « portrait familial » de générations innombrables de "
-        #     "étoiles : les plus anciennes sont observées sous forme de pointes bleues, "
-        #     "alors que les « nouveau-nés » de couleur rose dans la salle des accouchements doivent "
-        #     "être plus difficiles "
-        # )
 
         self.assertEqual(translation, new_truncated_translation)
...
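The test renames above track the API change from past_key_value_states to past_key_values, aligning T5 with the naming used by GPT-2 and Bart. A minimal sketch of decoder-side caching under the new name, following the tester's pattern with a hypothetical small config:

import torch
from transformers import T5Config, T5Model

# hypothetical small config, not taken from the test suite
config = T5Config(vocab_size=99, d_model=32, d_ff=64, num_layers=2, num_heads=2, return_dict=True)
decoder = T5Model(config).get_decoder().eval()

input_ids = torch.randint(0, config.vocab_size, (2, 7))
with torch.no_grad():
    # first pass caches each layer's keys and values
    output, past_key_values = decoder(input_ids, use_cache=True).to_tuple()
    next_tokens = torch.randint(0, config.vocab_size, (2, 3))
    # only the new tokens are fed; the cache supplies the earlier positions
    cached = decoder(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
    full = decoder(torch.cat([input_ids, next_tokens], dim=-1))["last_hidden_state"]
assert torch.allclose(cached, full[:, -3:], atol=1e-3)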
@@ -136,6 +136,29 @@ class TFModelTesterMixin:
             outputs = run_in_graph_mode()
             self.assertIsNotNone(outputs)
 
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.call)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            if model.config.is_encoder_decoder:
+                expected_arg_names = [
+                    "inputs",
+                    "attention_mask",
+                    "decoder_input_ids",
+                    "decoder_attention_mask",
+                    "encoder_outputs",
+                ]
+                self.assertListEqual(arg_names[:5], expected_arg_names)
+            else:
+                expected_arg_names = ["inputs"]
+                self.assertListEqual(arg_names[:1], expected_arg_names)
+
     @slow
     def test_saved_model_with_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -152,7 +175,12 @@ class TFModelTesterMixin:
                 tf.saved_model.save(model, tmpdirname)
                 model = tf.keras.models.load_model(tmpdirname)
                 outputs = model(inputs_dict)
-                output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
+
+                if self.is_encoder_decoder:
+                    output = outputs["encoder_hidden_states"] if isinstance(outputs, dict) else outputs[-1]
+                else:
+                    output = outputs["hidden_states"] if isinstance(outputs, dict) else outputs[-1]
+
                 hidden_states = [t.numpy() for t in output]
                 self.assertEqual(len(outputs), num_out)
                 self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
@@ -185,7 +213,12 @@ class TFModelTesterMixin:
                 tf.saved_model.save(model, tmpdirname)
                 model = tf.keras.models.load_model(tmpdirname)
                 outputs = model(inputs_dict)
-                output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
+
+                if self.is_encoder_decoder:
+                    output = outputs["encoder_attentions"] if isinstance(outputs, dict) else outputs[-1]
+                else:
+                    output = outputs["attentions"] if isinstance(outputs, dict) else outputs[-1]
+
                 attentions = [t.numpy() for t in output]
                 self.assertEqual(len(outputs), num_out)
                 self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
...
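The new test_forward_signature pins down the keyword order of call, which matters because TF models are frequently invoked positionally. A standalone illustration of the mechanism; the call below is a stand-in for a model method, not a real transformers API:

import inspect

# stand-in `call`: the first five arguments of a TF seq2seq model's call
# must appear in exactly this order for the signature test to pass
def call(self, inputs, attention_mask=None, decoder_input_ids=None,
         decoder_attention_mask=None, encoder_outputs=None, **kwargs):
    pass

# signature.parameters preserves declaration order, so slicing checks layout
arg_names = [*inspect.signature(call).parameters.keys()]
assert arg_names[1:6] == [
    "inputs",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "encoder_outputs",
]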
@@ -211,6 +211,36 @@ class TFGPT2ModelTester:
         # test that outputs are equal for slice
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
 
+    def create_and_check_gpt2_model_past_large_inputs(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
+    ):
+        model = TFGPT2Model(config=config)
+
+        # first forward pass
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+
+        output, past = outputs.to_tuple()
+
+        # create hypothetical next token and extent to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+        next_token_types = ids_tensor((self.batch_size, 3), self.type_vocab_size)
+
+        # append to next input_ids and token_type_ids
+        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+        next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)
+
+        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
+        output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
+        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+        # select random slice
+        random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+        output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+        # test that outputs are equal for slice
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+
     def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = TFGPT2LMHeadModel(config=config)
         inputs = {
@@ -290,6 +320,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_model_attention_mask_past(*config_and_inputs)
 
+    def test_gpt2_model_past_large_inputs(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gpt2_model_past_large_inputs(*config_and_inputs)
+
     def test_gpt2_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
...
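This TF tester mirrors the PyTorch check above. A self-contained sketch of the same cache equivalence in TensorFlow; the small config values are hypothetical, and past is the kwarg name at this commit:

import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

# hypothetical small config, not taken from the test suite
config = GPT2Config(vocab_size=99, n_positions=64, n_embd=32, n_layer=2, n_head=2, return_dict=True)
model = TFGPT2Model(config)

input_ids = tf.random.uniform((2, 7), maxval=config.vocab_size, dtype=tf.int32)
# first pass returns the key/value cache alongside the hidden states
output, past = model(input_ids, use_cache=True).to_tuple()

next_tokens = tf.random.uniform((2, 3), maxval=config.vocab_size, dtype=tf.int32)
from_past = model(next_tokens, past=past)["last_hidden_state"]
full = model(tf.concat([input_ids, next_tokens], axis=-1))["last_hidden_state"]

# the cached continuation must match the full pass on the new positions
tf.debugging.assert_near(from_past, full[:, -3:], rtol=1e-6)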
This diff is collapsed.
@@ -511,7 +511,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
 
     def test_xlnet_base_model_use_cache(self):
-        # checking that in auto-regressive mode, `use_cache` gives the same results
+        # checking that in auto-regressive mode, :obj:`use_cache` gives the same results
         self.model_tester.set_seed()
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
...