Unverified Commit b2c1a447 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[BART] Delete redundant unit test (#3302)

parent b2028cc2
...@@ -381,7 +381,7 @@ TOLERANCE = 1e-4 ...@@ -381,7 +381,7 @@ TOLERANCE = 1e-4
@require_torch @require_torch
class BartModelIntegrationTest(unittest.TestCase): class BartModelIntegrationTests(unittest.TestCase):
@slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
model = BartModel.from_pretrained("bart-large").to(torch_device) model = BartModel.from_pretrained("bart-large").to(torch_device)
...@@ -431,25 +431,7 @@ class BartModelIntegrationTest(unittest.TestCase): ...@@ -431,25 +431,7 @@ class BartModelIntegrationTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow
def test_cnn_summarization_same_as_fairseq_easy(self):
    """Summarize one short CNN article with beam search and compare against the fairseq reference output."""
    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True).to(torch_device)
    tokenizer = BartTokenizer.from_pretrained("bart-large")
    article = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
    input_ids = tokenizer.encode(article, return_tensors="pt").to(torch_device)
    max_extra_tokens = 20
    # Deterministic beam search; decoding starts from the EOS token, mirroring fairseq's convention.
    summary_ids = model.generate(
        input_ids,
        num_beams=4,
        max_length=max_extra_tokens + 2,
        do_sample=False,
        decoder_start_token_id=model.config.eos_token_ids[0],
    )
    expected = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
    decoded = [tokenizer.decode(ids) for ids in summary_ids]
    self.assertEqual(expected, decoded[0])
@slow
def test_cnn_summarization_same_as_fairseq_hard(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device) hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large") tok = BartTokenizer.from_pretrained("bart-large")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment