"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1429b920d44d610eaa0a6f48de43853da52e9c03"
Unverified Commit b86a71ea authored by Sam Shleifer, committed by GitHub

[tests] fix slow bart cnn test, faster marian tests (#7888)

parent ba8c4d0a
@@ -594,7 +594,9 @@ class BartModelIntegrationTests(unittest.TestCase):
             "Bronx on Friday. If convicted, she faces up to four years in prison.",
         ]
-        generated_summaries = [tok.batch_decode(hypotheses_batch.tolist())]
+        generated_summaries = tok.batch_decode(
+            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
+        )
         assert generated_summaries == EXPECTED
...
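Note: the updated BART test leans on the fact that batch_decode already returns one string per hypothesis, so the extra list wrapping is dropped and special tokens/tokenization spaces are cleaned up before comparing against EXPECTED. A minimal sketch of that decoding step outside the test suite (the facebook/bart-large-cnn checkpoint, generation settings, and input text below are assumptions for illustration, not taken from this diff):

# Sketch only: reproduce the decoding pattern used in the updated test.
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

batch = tok(["A long news article to summarize ..."], return_tensors="pt", truncation=True)
hypotheses_batch = model.generate(**batch, num_beams=2, max_length=62)

# batch_decode returns a list of strings (one per generated sequence), so no
# list wrapping is needed; special tokens are stripped and spaces cleaned up.
generated_summaries = tok.batch_decode(
    hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
print(generated_summaries)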
@@ -16,7 +16,7 @@
 import unittest
 
-from transformers import is_torch_available
+from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available
 from transformers.file_utils import cached_property
 from transformers.hf_api import HfApi
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
 if is_torch_available():
     import torch
 
-    from transformers import (
-        AutoConfig,
-        AutoModelWithLMHead,
-        AutoTokenizer,
-        MarianConfig,
-        MarianMTModel,
-        MarianTokenizer,
-    )
+    from transformers import AutoModelWithLMHead, MarianMTModel
     from transformers.convert_marian_to_pytorch import (
         ORG_NAME,
         convert_hf_name_to_opus_name,
@@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
-        cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
-        cls.eos_token_id = cls.tokenizer.eos_token_id
         return cls
 
+    @cached_property
+    def tokenizer(self) -> MarianTokenizer:
+        return AutoTokenizer.from_pretrained(self.model_name)
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_token_id
+
     @cached_property
     def model(self):
         model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
...
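Note: the Marian change moves tokenizer construction out of setUpClass and behind a cached_property, so subclasses whose tests never touch the tokenizer (or are skipped entirely) no longer download it up front, which is what makes the tests faster. A minimal sketch of that lazy-loading pattern, assuming an illustrative test class and the en-de checkpoint (neither is taken from this diff):

# Sketch only: lazy tokenizer loading for an integration-style test.
import unittest

from transformers import AutoTokenizer
from transformers.file_utils import cached_property


class ExampleMarianTest(unittest.TestCase):
    src, tgt = "en", "de"  # illustrative language pair

    @classmethod
    def setUpClass(cls) -> None:
        # Only cheap string setup happens eagerly.
        cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
        return cls

    @cached_property
    def tokenizer(self):
        # Downloaded/instantiated only on first access, then cached on the instance.
        return AutoTokenizer.from_pretrained(self.model_name)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    def test_eos_is_int(self):
        self.assertIsInstance(self.eos_token_id, int)

With this layout, setUpClass stays cheap, and the tokenizer download only happens inside tests that actually use it.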