Commit 4f15e5a2 authored by Julien Chaumond

Add tests.

Maybe not the best possible place for the tests; let me know.
parent 18e1f751
@@ -22,7 +22,7 @@ import logging
 from transformers import is_torch_available
-from .utils import require_torch, slow
+from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
 if is_torch_available():
     from transformers import (AutoConfig, BertConfig,
@@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, BertForQuestionAnswering)
+    def test_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
+        self.assertIsInstance(model, BertForMaskedLM)
 if __name__ == "__main__":
     unittest.main()
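For context, a standalone sketch of what the new PyTorch test exercises (it assumes network access to the model hub; "julien-c/bert-xsmall-dummy" is the tiny dummy checkpoint the new constant points at):

import logging

from transformers import AutoModelWithLMHead, BertForMaskedLM

logging.basicConfig(level=logging.INFO)

# from_pretrained resolves a "user/model" identifier to hosted config and
# weight files, then dispatches to the architecture named in the config.
model = AutoModelWithLMHead.from_pretrained("julien-c/bert-xsmall-dummy")
assert isinstance(model, BertForMaskedLM)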
@@ -22,7 +22,7 @@ import logging
 from transformers import is_tf_available
-from .utils import require_tf, slow
+from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER
 if is_tf_available():
     from transformers import (AutoConfig, BertConfig,
@@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, TFBertForQuestionAnswering)
+    def test_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
+        self.assertIsInstance(model, TFBertForMaskedLM)
 if __name__ == "__main__":
     unittest.main()
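The TensorFlow test is the same check, but passes force_download=True so a cached copy is never reused and the download-and-resolve path is always exercised. A minimal sketch under the same assumptions as above:

from transformers import TFAutoModelWithLMHead, TFBertForMaskedLM

# force_download=True bypasses the local cache entirely.
model = TFAutoModelWithLMHead.from_pretrained(
    "julien-c/bert-xsmall-dummy", force_download=True
)
assert isinstance(model, TFBertForMaskedLM)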
@@ -23,7 +23,7 @@ import logging
 from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer
 from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .utils import slow
+from .utils import slow, SMALL_MODEL_IDENTIFIER
 class AutoTokenizerTest(unittest.TestCase):
@@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase):
         self.assertIsInstance(tokenizer, GPT2Tokenizer)
         self.assertGreater(len(tokenizer), 0)
+    def test_tokenizer_from_pretrained_identifier(self):
+        logging.basicConfig(level=logging.INFO)
+        tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
+        self.assertIsInstance(tokenizer, BertTokenizer)
+        self.assertEqual(len(tokenizer), 12)
 if __name__ == "__main__":
     unittest.main()
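The tokenizer test checks both dispatch (the dummy checkpoint's config maps to BertTokenizer) and the size of its vocabulary. A standalone sketch:

from transformers import AutoTokenizer, BertTokenizer

tokenizer = AutoTokenizer.from_pretrained("julien-c/bert-xsmall-dummy")
assert isinstance(tokenizer, BertTokenizer)
# The dummy checkpoint ships a 12-entry vocab, hence the exact assertion.
assert len(tokenizer) == 12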
@@ -6,6 +6,9 @@ from distutils.util import strtobool
 from transformers.file_utils import _tf_available, _torch_available
+SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+
 try:
     run_slow = os.environ["RUN_SLOW"]
 except KeyError:
...
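The diff is truncated here. For context, a sketch of the likely RUN_SLOW handling, inferred from the strtobool import in the context line; the except branch is an assumption, not shown in this diff:

import os
from distutils.util import strtobool

try:
    # strtobool accepts "1"/"true"/"yes" and friends.
    run_slow = strtobool(os.environ["RUN_SLOW"])
except KeyError:
    run_slow = False  # assumed default: slow tests are skipped unless RUN_SLOW is set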