Commit cf8a70bf authored by Julien Chaumond's avatar Julien Chaumond
Browse files

More AutoConfig tests

parent 6bb3edc3
...@@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig ...@@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig
from transformers.configuration_bert import BertConfig from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig from transformers.configuration_roberta import RobertaConfig
from .utils import DUMMY_UNKWOWN_IDENTIFIER
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
...@@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase): ...@@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase):
config = AutoConfig.from_pretrained("bert-base-uncased") config = AutoConfig.from_pretrained("bert-base-uncased")
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
def test_config_from_model_type(self): def test_config_model_type_from_local_file(self):
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG) config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
self.assertIsInstance(config, RobertaConfig) self.assertIsInstance(config, RobertaConfig)
def test_config_model_type_from_model_identifier(self):
    """A hub identifier whose name does not hint at the architecture should
    still resolve to the right config class via the config's model_type."""
    resolved = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
    self.assertIsInstance(resolved, RobertaConfig)
def test_config_for_model_str(self): def test_config_for_model_str(self):
config = AutoConfig.for_model("roberta") config = AutoConfig.for_model("roberta")
self.assertIsInstance(config, RobertaConfig) self.assertIsInstance(config, RobertaConfig)
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
if is_torch_available(): if is_torch_available():
...@@ -30,6 +30,7 @@ if is_torch_available(): ...@@ -30,6 +30,7 @@ if is_torch_available():
BertModel, BertModel,
AutoModelWithLMHead, AutoModelWithLMHead,
BertForMaskedLM, BertForMaskedLM,
RobertaForMaskedLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
BertForSequenceClassification, BertForSequenceClassification,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
...@@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase): ...@@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase):
self.assertIsInstance(model, BertForMaskedLM) self.assertIsInstance(model, BertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830) self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830) self.assertEqual(model.num_parameters(only_trainable=True), 14830)
def test_from_identifier_from_model_type(self):
    """Loading a dummy identifier must dispatch to the Roberta LM-head model
    based on the checkpoint's model_type, not the identifier's name."""
    logging.basicConfig(level=logging.INFO)
    lm_model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
    self.assertIsInstance(lm_model, RobertaForMaskedLM)
    # Tiny dummy checkpoint: every one of its 14830 parameters is trainable.
    expected_param_count = 14830
    self.assertEqual(lm_model.num_parameters(), expected_param_count)
    self.assertEqual(lm_model.num_parameters(only_trainable=True), expected_param_count)
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from transformers import is_tf_available from transformers import is_tf_available
from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
if is_tf_available(): if is_tf_available():
...@@ -30,6 +30,7 @@ if is_tf_available(): ...@@ -30,6 +30,7 @@ if is_tf_available():
TFBertModel, TFBertModel,
TFAutoModelWithLMHead, TFAutoModelWithLMHead,
TFBertForMaskedLM, TFBertForMaskedLM,
TFRobertaForMaskedLM,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
TFBertForSequenceClassification, TFBertForSequenceClassification,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
...@@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase):
self.assertIsInstance(model, TFBertForMaskedLM) self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830) self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830) self.assertEqual(model.num_parameters(only_trainable=True), 14830)
def test_from_identifier_from_model_type(self):
    """TF mirror of the PyTorch test: the dummy identifier must dispatch to
    the TF Roberta LM-head model via the checkpoint's model_type."""
    logging.basicConfig(level=logging.INFO)
    lm_model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
    self.assertIsInstance(lm_model, TFRobertaForMaskedLM)
    # Tiny dummy checkpoint: every one of its 14830 parameters is trainable.
    expected_param_count = 14830
    self.assertEqual(lm_model.num_parameters(), expected_param_count)
    self.assertEqual(lm_model.num_parameters(only_trainable=True), expected_param_count)
...@@ -23,9 +23,10 @@ from transformers import ( ...@@ -23,9 +23,10 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
BertTokenizer, BertTokenizer,
GPT2Tokenizer, GPT2Tokenizer,
RobertaTokenizer,
) )
from .utils import SMALL_MODEL_IDENTIFIER, slow # noqa: F401 from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noqa: F401
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
...@@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, BertTokenizer) self.assertIsInstance(tokenizer, BertTokenizer)
self.assertEqual(len(tokenizer), 12) self.assertEqual(len(tokenizer), 12)
def test_tokenizer_from_model_type(self):
    """AutoTokenizer must pick RobertaTokenizer for the dummy identifier
    by reading the checkpoint's model_type."""
    logging.basicConfig(level=logging.INFO)
    tok = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
    self.assertIsInstance(tok, RobertaTokenizer)
    # The dummy checkpoint ships a 20-entry vocabulary.
    self.assertEqual(len(tok), 20)
...@@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available ...@@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test") CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Used to test Auto{Config, Model, Tokenizer} model_type detection: the
# identifier's name carries no architecture hint, so dispatch must come from
# the checkpoint's config.
# NOTE(review): "UNKWOWN" is a typo for "UNKNOWN", but other test modules
# import this exact name — renaming it here alone would break those imports.
def parse_flag_from_env(key, default=False): def parse_flag_from_env(key, default=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment