"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1ee97350ca8a19bb3d780dea3a80135ee417a578"
Commit 4f15e5a2 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Add tests.

Maybe not the best possible place for the tests, lmk.
parent 18e1f751
...@@ -22,7 +22,7 @@ import logging ...@@ -22,7 +22,7 @@ import logging
from transformers import is_torch_available from transformers import is_torch_available
from .utils import require_torch, slow from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
if is_torch_available(): if is_torch_available():
from transformers import (AutoConfig, BertConfig, from transformers import (AutoConfig, BertConfig,
...@@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase): ...@@ -92,6 +92,11 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, BertForQuestionAnswering) self.assertIsInstance(model, BertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -22,7 +22,7 @@ import logging ...@@ -22,7 +22,7 @@ import logging
from transformers import is_tf_available from transformers import is_tf_available
from .utils import require_tf, slow from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER
if is_tf_available(): if is_tf_available():
from transformers import (AutoConfig, BertConfig, from transformers import (AutoConfig, BertConfig,
...@@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -93,6 +93,11 @@ class TFAutoModelTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForQuestionAnswering) self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
self.assertIsInstance(model, TFBertForMaskedLM)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -23,7 +23,7 @@ import logging ...@@ -23,7 +23,7 @@ import logging
from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .utils import slow from .utils import slow, SMALL_MODEL_IDENTIFIER
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
...@@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -42,6 +42,11 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer, GPT2Tokenizer) self.assertIsInstance(tokenizer, GPT2Tokenizer)
self.assertGreater(len(tokenizer), 0) self.assertGreater(len(tokenizer), 0)
def test_tokenizer_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, BertTokenizer)
self.assertEqual(len(tokenizer), 12)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -6,6 +6,9 @@ from distutils.util import strtobool ...@@ -6,6 +6,9 @@ from distutils.util import strtobool
from transformers.file_utils import _tf_available, _torch_available from transformers.file_utils import _tf_available, _torch_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
try: try:
run_slow = os.environ["RUN_SLOW"] run_slow = os.environ["RUN_SLOW"]
except KeyError: except KeyError:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment