Unverified Commit e11d923b authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Fix pegasus-xsum integration test (#6726)

parent 7e6397a7
...@@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): ...@@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "google/pegasus-xsum" checkpoint_name = "google/pegasus-xsum"
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER] src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
tgt_text = [ tgt_text = [
"California's largest electricity provider has turned off power to tens of thousands of customers.", "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
"N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.", "N-Dubz have said they were surprised to get four nominations for this year's Mobo Awards.",
] ]
@cached_property @cached_property
...@@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): ...@@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
assert inputs.input_ids.shape == (2, 421) assert inputs.input_ids.shape == (2, 421)
translated_tokens = self.model.generate(**inputs) translated_tokens = self.model.generate(**inputs)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text, decoded) assert self.tgt_text == decoded
if "cuda" not in torch_device: if "cuda" not in torch_device:
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment