Commit 4d1c98c0 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

AutoConfig + other Auto classes honor model_type

parent 2f32dfd3
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import logging import logging
from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
...@@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open ...@@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
...@@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( ...@@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
) )
# Ordered map of model-type key -> configuration class.
# Order matters: when falling back to substring matching, more specific keys
# must come before their substrings (e.g. "xlm-roberta" before "roberta",
# "roberta" before "bert", "xlnet"/"xlm-roberta" before "xlm").
CONFIG_MAPPING = OrderedDict(
    [
        ("t5", T5Config),
        ("distilbert", DistilBertConfig),
        ("albert", AlbertConfig),
        ("camembert", CamembertConfig),
        ("xlm-roberta", XLMRobertaConfig),
        ("roberta", RobertaConfig),
        ("bert", BertConfig),
        ("openai-gpt", OpenAIGPTConfig),
        ("gpt2", GPT2Config),
        ("transfo-xl", TransfoXLConfig),
        ("xlnet", XLNetConfig),
        ("xlm", XLMConfig),
        ("ctrl", CTRLConfig),
    ]
)
class AutoConfig:
r""":class:`~transformers.AutoConfig` is a generic configuration class r""":class:`~transformers.AutoConfig` is a generic configuration class
that will be instantiated as one of the configuration classes of the library that will be instantiated as one of the configuration classes of the library
when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
class method. class method.
The `from_pretrained()` method take care of returning the correct model class instance The `from_pretrained()` method take care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string. based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching When using string matching, the configuration class is matched on
in the `pretrained_model_name_or_path` string (in the following order): the `pretrained_model_name_or_path` string in the following order:
- contains `t5`: T5Config (T5 model)
- contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model) - contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model) - contains `camembert`: CamembertConfig (CamemBERT model)
...@@ -90,41 +113,23 @@ class AutoConfig(object): ...@@ -90,41 +113,23 @@ class AutoConfig(object):
@classmethod @classmethod
def for_model(cls, model_type, *args, **kwargs): def for_model(cls, model_type, *args, **kwargs):
if "distilbert" in model_type: for pattern, config_class in CONFIG_MAPPING.items():
return DistilBertConfig(*args, **kwargs) if pattern in model_type:
elif "roberta" in model_type: return config_class(*args, **kwargs)
return RobertaConfig(*args, **kwargs)
elif "bert" in model_type:
return BertConfig(*args, **kwargs)
elif "openai-gpt" in model_type:
return OpenAIGPTConfig(*args, **kwargs)
elif "gpt2" in model_type:
return GPT2Config(*args, **kwargs)
elif "transfo-xl" in model_type:
return TransfoXLConfig(*args, **kwargs)
elif "xlnet" in model_type:
return XLNetConfig(*args, **kwargs)
elif "xlm" in model_type:
return XLMConfig(*args, **kwargs)
elif "ctrl" in model_type:
return CTRLConfig(*args, **kwargs)
elif "albert" in model_type:
return AlbertConfig(*args, **kwargs)
elif "camembert" in model_type:
return CamembertConfig(*args, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized model identifier in {}. Should contain one of {}".format(
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " model_type, ", ".join(CONFIG_MAPPING.keys())
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type) )
) )
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a one of the configuration classes of the library r""" Instantiate one of the configuration classes of the library
from a pre-trained model configuration. from a pre-trained model configuration.
The configuration class to instantiate is selected as the first pattern matching The configuration class to instantiate is selected
in the `pretrained_model_name_or_path` string (in the following order): based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
- contains `t5`: T5Config (T5 model) - contains `t5`: T5Config (T5 model)
- contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model) - contains `albert`: AlbertConfig (ALBERT model)
...@@ -183,36 +188,21 @@ class AutoConfig(object): ...@@ -183,36 +188,21 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
if "t5" in pretrained_model_name_or_path: config_dict, _ = PretrainedConfig.resolved_config_dict(
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs) pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
elif "distilbert" in pretrained_model_name_or_path: )
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "albert" in pretrained_model_name_or_path: if "model_type" in config_dict:
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config_class = CONFIG_MAPPING[config_dict["model_type"]]
elif "camembert" in pretrained_model_name_or_path: return config_class.from_dict(config_dict, **kwargs)
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) else:
elif "xlm-roberta" in pretrained_model_name_or_path: # Fallback: use pattern matching on the string.
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) for pattern, config_class in CONFIG_MAPPING.items():
elif "roberta" in pretrained_model_name_or_path: if pattern in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return config_class.from_dict(config_dict, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
) )
) )
...@@ -20,6 +20,7 @@ import copy ...@@ -20,6 +20,7 @@ import copy
import json import json
import logging import logging
import os import os
from typing import Dict, Optional, Tuple
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
...@@ -36,7 +37,7 @@ class PretrainedConfig(object): ...@@ -36,7 +37,7 @@ class PretrainedConfig(object):
It only affects the model's configuration. It only affects the model's configuration.
Class attributes (overridden by derived classes): Class attributes (overridden by derived classes):
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
Parameters: Parameters:
``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
...@@ -154,14 +155,32 @@ class PretrainedConfig(object): ...@@ -154,14 +155,32 @@ class PretrainedConfig(object):
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(config_dict, **kwargs)
@classmethod
def resolved_config_dict(
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
for instantiating a Config using `from_dict`.
Parameters:
pretrained_config_archive_map: (`optional`) Dict:
A map of `shortcut names` to `url`.
By default, will use the current class attribute.
"""
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_config_archive_map is None:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] pretrained_config_archive_map = cls.pretrained_config_archive_map
if pretrained_model_name_or_path in pretrained_config_archive_map:
config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path): elif os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
...@@ -178,23 +197,20 @@ class PretrainedConfig(object): ...@@ -178,23 +197,20 @@ class PretrainedConfig(object):
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
) )
# Load config # Load config dict
config = cls.from_json_file(resolved_config_file) config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file config_file
) )
else: else:
msg = ( msg = (
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url to a configuration file named {} or " "We assumed '{}' was a path or url to a configuration file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format( "a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path, pretrained_model_name_or_path, config_file, CONFIG_NAME,
", ".join(cls.pretrained_config_archive_map.keys()),
config_file,
CONFIG_NAME,
) )
) )
raise EnvironmentError(msg) raise EnvironmentError(msg)
...@@ -212,6 +228,15 @@ class PretrainedConfig(object): ...@@ -212,6 +228,15 @@ class PretrainedConfig(object):
else: else:
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
return config_dict, kwargs
@classmethod
def from_dict(cls, config_dict: Dict, **kwargs):
"""Constructs a `Config` from a Python dictionary of parameters."""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
config = cls(**config_dict)
if hasattr(config, "pruned_heads"): if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
...@@ -231,17 +256,16 @@ class PretrainedConfig(object): ...@@ -231,17 +256,16 @@ class PretrainedConfig(object):
return config return config
@classmethod @classmethod
def from_dict(cls, json_object): def from_json_file(cls, json_file: str):
"""Constructs a `Config` from a Python dictionary of parameters.""" """Constructs a `Config` from a json file of parameters."""
return cls(**json_object) config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)
@classmethod @classmethod
def from_json_file(cls, json_file): def _dict_from_json_file(cls, json_file: str):
"""Constructs a `Config` from a json file of parameters."""
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
dict_obj = json.loads(text) return json.loads(text)
return cls(**dict_obj)
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
......
This diff is collapsed.
This diff is collapsed.
...@@ -17,6 +17,23 @@ ...@@ -17,6 +17,23 @@
import logging import logging
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer from .tokenization_albert import AlbertTokenizer
from .tokenization_bert import BertTokenizer from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer
...@@ -43,7 +60,8 @@ class AutoTokenizer(object): ...@@ -43,7 +60,8 @@ class AutoTokenizer(object):
class method. class method.
The `from_pretrained()` method take care of returning the correct tokenizer class instance The `from_pretrained()` method take care of returning the correct tokenizer class instance
using pattern matching on the `pretrained_model_name_or_path` string. based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The tokenizer class to instantiate is selected as the first pattern matching The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
...@@ -72,7 +90,7 @@ class AutoTokenizer(object): ...@@ -72,7 +90,7 @@ class AutoTokenizer(object):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r""" Instantiate a one of the tokenizer classes of the library r""" Instantiate one of the tokenizer classes of the library
from a pre-trained model vocabulary. from a pre-trained model vocabulary.
The tokenizer class to instantiate is selected as the first pattern matching The tokenizer class to instantiate is selected as the first pattern matching
...@@ -129,33 +147,38 @@ class AutoTokenizer(object): ...@@ -129,33 +147,38 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
""" """
if "t5" in pretrained_model_name_or_path: config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if isinstance(config, T5Config):
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "distilbert" in pretrained_model_name_or_path: elif isinstance(config, DistilBertConfig):
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "albert" in pretrained_model_name_or_path: elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "camembert" in pretrained_model_name_or_path: elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path: elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "roberta" in pretrained_model_name_or_path: elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert-base-japanese" in pretrained_model_name_or_path: elif isinstance(config, BertConfig):
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path: elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "gpt2" in pretrained_model_name_or_path: elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path: elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlnet" in pretrained_model_name_or_path: elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm" in pretrained_model_name_or_path: elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "ctrl" in pretrained_model_name_or_path: elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contains one of " "Unrecognized model identifier in {}. Should contains one of "
......
{
"model_type": "roberta"
}
\ No newline at end of file
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.configuration_auto import AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase):
    def test_config_from_model_shortcut(self):
        # A shortcut name resolves through the pretrained archive map.
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_from_model_type(self):
        # A local config file carrying a `model_type` key selects the
        # configuration class directly, regardless of the file name.
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        # for_model() builds a fresh config from a model-type string.
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)
...@@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase): ...@@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, BertForSequenceClassification) self.assertIsInstance(model, BertForSequenceClassification)
@slow # @slow
def test_question_answering_model_from_pretrained(self): def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow ...@@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
@slow # @slow
def test_tokenizer_from_pretrained(self): def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment