Unverified Commit 9a0399e1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix bart tests (#10060)

parent b01483fa
......@@ -42,7 +42,6 @@ if is_torch_available():
BartForSequenceClassification,
BartModel,
BartTokenizer,
BartTokenizerFast,
pipeline,
)
from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
......@@ -566,10 +565,6 @@ class BartModelIntegrationTests(unittest.TestCase):
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@cached_property
def default_tokenizer_fast(self):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@slow
def test_inference_no_head(self):
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
......@@ -589,14 +584,14 @@ class BartModelIntegrationTests(unittest.TestCase):
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in pbase(src_text)]
assert "Ġbathroom" in results
assert " bathroom" in results
@slow
def test_large_mask_filling(self):
plarge = pipeline(task="fill-mask", model="facebook/bart-large")
src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in plarge(src_text)]
expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"]
expected_results = [" bathroom", " gym", " wrong", " movies", " hospital"]
self.assertListEqual(results, expected_results)
@slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment