"...resnet50_tensorflow.git" did not exist on "4b617781f28e6cccd5b081856e7797d4c5b9acd4"
Unverified Commit b3b9f99e authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix test_t5_decoder_model_past_large_inputs (#17320)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 6da76b9c
...@@ -295,6 +295,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -295,6 +295,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
def test_t5_decoder_model_past_large_inputs(self): def test_t5_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
# `create_and_check_t5_decoder_model_past_large_inputs` has special inputs:
# (config, input_ids, decoder_input_ids, attention_mask)
# and we have to prepare it correctly here.
config, input_ids, input_mask, token_labels = config_and_inputs
config_and_inputs = (config, input_ids, None, input_mask)
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs) self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_t5_model_xla_generate_fast(self): def test_t5_model_xla_generate_fast(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment