Unverified Commit 449df41f authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `TFEncoderDecoder` tests (#21301)



remove max_length=None
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 857bad6e
...@@ -785,7 +785,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): ...@@ -785,7 +785,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months.""" EXPECTED_SUMMARY_STUDENTS = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf") input_dict = tokenizer(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist() output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True) summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS]) self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
...@@ -793,7 +793,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): ...@@ -793,7 +793,7 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
# Test with the TF checkpoint # Test with the TF checkpoint
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16") model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist() output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True) summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS]) self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
...@@ -887,7 +887,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): ...@@ -887,7 +887,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption.""" EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf") input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="tf")
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist() output_ids = model.generate(input_ids=input_dict["input_ids"]).numpy().tolist()
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True) summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS]) self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment