Unverified Commit 9a0399e1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix bart tests (#10060)

parent b01483fa
...@@ -42,7 +42,6 @@ if is_torch_available(): ...@@ -42,7 +42,6 @@ if is_torch_available():
BartForSequenceClassification, BartForSequenceClassification,
BartModel, BartModel,
BartTokenizer, BartTokenizer,
BartTokenizerFast,
pipeline, pipeline,
) )
from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
...@@ -566,10 +565,6 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -566,10 +565,6 @@ class BartModelIntegrationTests(unittest.TestCase):
def default_tokenizer(self): def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large") return BartTokenizer.from_pretrained("facebook/bart-large")
@cached_property
def default_tokenizer_fast(self):
return BartTokenizerFast.from_pretrained("facebook/bart-large")
@slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device) model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
...@@ -589,14 +584,14 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -589,14 +584,14 @@ class BartModelIntegrationTests(unittest.TestCase):
pbase = pipeline(task="fill-mask", model="facebook/bart-base") pbase = pipeline(task="fill-mask", model="facebook/bart-base")
src_text = [" I went to the <mask>."] src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in pbase(src_text)] results = [x["token_str"] for x in pbase(src_text)]
assert "Ġbathroom" in results assert " bathroom" in results
@slow @slow
def test_large_mask_filling(self): def test_large_mask_filling(self):
plarge = pipeline(task="fill-mask", model="facebook/bart-large") plarge = pipeline(task="fill-mask", model="facebook/bart-large")
src_text = [" I went to the <mask>."] src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in plarge(src_text)] results = [x["token_str"] for x in plarge(src_text)]
expected_results = ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"] expected_results = [" bathroom", " gym", " wrong", " movies", " hospital"]
self.assertListEqual(results, expected_results) self.assertListEqual(results, expected_results)
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment