Unverified Commit d4c2cb40 authored by Julien Chaumond, committed by GitHub

Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
parent 47a551d1
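
For context, the user-visible effect of this change: pretrained weights are now resolved from the model identifier on huggingface.co, and the *_PRETRAINED_MODEL_ARCHIVE_LIST constants introduced below are informational only. A minimal sketch (model ids taken from the lists in this diff; the TF class requires TensorFlow to be installed):

from transformers import TFElectraModel, ElectraTokenizer

# The identifier is looked up on huggingface.co; there is no longer a
# hard-coded *_PRETRAINED_MODEL_ARCHIVE_MAP consulted at load time.
model = TFElectraModel.from_pretrained("google/electra-small-discriminator")
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")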
......@@ -13,14 +13,15 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
"google/electra-small-generator": "https://cdn.huggingface.co/google/electra-small-generator/tf_model.h5",
"google/electra-base-generator": "https://cdn.huggingface.co/google/electra-base-generator/tf_model.h5",
"google/electra-large-generator": "https://cdn.huggingface.co/google/electra-large-generator/tf_model.h5",
"google/electra-small-discriminator": "https://cdn.huggingface.co/google/electra-small-discriminator/tf_model.h5",
"google/electra-base-discriminator": "https://cdn.huggingface.co/google/electra-base-discriminator/tf_model.h5",
"google/electra-large-discriminator": "https://cdn.huggingface.co/google/electra-large-discriminator/tf_model.h5",
}
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/electra-small-generator",
"google/electra-base-generator",
"google/electra-large-generator",
"google/electra-small-discriminator",
"google/electra-base-discriminator",
"google/electra-large-discriminator",
# See all ELECTRA models at https://huggingface.co/models?filter=electra
]
class TFElectraEmbeddings(tf.keras.layers.Layer):
......@@ -160,7 +161,6 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
class TFElectraPreTrainedModel(TFBertPreTrainedModel):
config_class = ElectraConfig
pretrained_model_archive_map = TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "electra"
def get_extended_attention_mask(self, attention_mask, input_shape):
......
......@@ -35,7 +35,9 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {}
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Flaubert models at https://huggingface.co/models?filter=flaubert
]
FLAUBERT_START_DOCSTRING = r"""
......@@ -104,7 +106,6 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
)
class TFFlaubertModel(TFXLMModel):
config_class = FlaubertConfig
pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
......@@ -309,7 +310,6 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
)
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
config_class = FlaubertConfig
pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
......@@ -323,7 +323,6 @@ class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
)
class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
config_class = FlaubertConfig
pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
......
......@@ -37,13 +37,14 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
"gpt2": "https://cdn.huggingface.co/gpt2-tf_model.h5",
"gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-tf_model.h5",
"gpt2-large": "https://cdn.huggingface.co/gpt2-large-tf_model.h5",
"gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-tf_model.h5",
"distilgpt2": "https://cdn.huggingface.co/distilgpt2-tf_model.h5",
}
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"gpt2",
"gpt2-medium",
"gpt2-large",
"gpt2-xl",
"distilgpt2",
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]
def gelu(x):
......@@ -389,7 +390,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
"""
config_class = GPT2Config
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
......
......@@ -36,7 +36,10 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://cdn.huggingface.co/openai-gpt-tf_model.h5"}
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai-gpt",
# See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
]
def gelu(x):
......@@ -349,7 +352,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
"""
config_class = OpenAIGPTConfig
pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
......
......@@ -28,12 +28,13 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
"roberta-base": "https://cdn.huggingface.co/roberta-base-tf_model.h5",
"roberta-large": "https://cdn.huggingface.co/roberta-large-tf_model.h5",
"roberta-large-mnli": "https://cdn.huggingface.co/roberta-large-mnli-tf_model.h5",
"distilroberta-base": "https://cdn.huggingface.co/distilroberta-base-tf_model.h5",
}
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"roberta-base",
"roberta-large",
"roberta-large-mnli",
"distilroberta-base",
# See all RoBERTa models at https://huggingface.co/models?filter=roberta
]
class TFRobertaEmbeddings(TFBertEmbeddings):
......@@ -100,7 +101,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
"""
config_class = RobertaConfig
pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
......
......@@ -30,13 +30,14 @@ from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
logger = logging.getLogger(__name__)
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
"t5-small": "https://cdn.huggingface.co/t5-small-tf_model.h5",
"t5-base": "https://cdn.huggingface.co/t5-base-tf_model.h5",
"t5-large": "https://cdn.huggingface.co/t5-large-tf_model.h5",
"t5-3b": "https://cdn.huggingface.co/t5-3b-tf_model.h5",
"t5-11b": "https://cdn.huggingface.co/t5-11b-tf_model.h5",
}
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
# See all T5 models at https://huggingface.co/models?filter=t5
]
####################################################
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
......@@ -720,7 +721,6 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
"""
config_class = T5Config
pretrained_model_archive_map = TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
@property
......
......@@ -30,9 +30,10 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
"transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-tf_model.h5",
}
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
"transfo-xl-wt103",
# See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
]
class TFPositionalEmbedding(tf.keras.layers.Layer):
......@@ -630,7 +631,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
......
......@@ -112,7 +112,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
......@@ -122,7 +121,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
@property
......@@ -338,9 +336,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# Load model
if pretrained_model_name_or_path is not None:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
......@@ -364,8 +360,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
use_cdn=use_cdn,
)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
......@@ -373,20 +369,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
resume_download=resume_download,
proxies=proxies,
)
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
)
)
raise e
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
......
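The hunks above reduce the TF weight-resolution logic to the following order (an illustrative sketch, not the verbatim implementation; TF2_WEIGHTS_NAME is the library constant for "tf_model.h5", and URL building/caching is handled by hf_bucket_url and cached_path in the real code):

import os

TF2_WEIGHTS_NAME = "tf_model.h5"  # library constant for TF 2.0 checkpoints

def resolve_archive_file(pretrained_model_name_or_path):
    # 1. Local directory: use the TF 2.0 checkpoint it contains.
    if os.path.isdir(pretrained_model_name_or_path):
        return os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
    # 2. Otherwise the string is treated as a model identifier (or a direct
    #    path/URL) and downloaded or served from the cache; the former
    #    cls.pretrained_model_archive_map lookup no longer exists.
    return pretrained_model_name_or_path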
......@@ -31,18 +31,19 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
"xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-tf_model.h5",
"xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-tf_model.h5",
"xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-tf_model.h5",
"xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-tf_model.h5",
"xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-tf_model.h5",
"xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-tf_model.h5",
"xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-tf_model.h5",
"xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-tf_model.h5",
"xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-tf_model.h5",
"xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-tf_model.h5",
}
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlm-mlm-en-2048",
"xlm-mlm-ende-1024",
"xlm-mlm-enfr-1024",
"xlm-mlm-enro-1024",
"xlm-mlm-tlm-xnli15-1024",
"xlm-mlm-xnli15-1024",
"xlm-clm-enfr-1024",
"xlm-clm-ende-1024",
"xlm-mlm-17-1280",
"xlm-mlm-100-1280",
# See all XLM models at https://huggingface.co/models?filter=xlm
]
def create_sinusoidal_embeddings(n_pos, dim, out):
......@@ -470,7 +471,6 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
"""
config_class = XLMConfig
pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
@property
......
......@@ -30,7 +30,9 @@ from .modeling_tf_roberta import (
logger = logging.getLogger(__name__)
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {}
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]
XLM_ROBERTA_START_DOCSTRING = r"""
......@@ -72,7 +74,6 @@ class TFXLMRobertaModel(TFRobertaModel):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -85,7 +86,6 @@ class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -100,7 +100,6 @@ class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -115,4 +114,3 @@ class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
......@@ -37,10 +37,11 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
"xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-tf_model.h5",
"xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-tf_model.h5",
}
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlnet-base-cased",
"xlnet-large-cased",
# See all XLNet models at https://huggingface.co/models?filter=xlnet
]
def gelu(x):
......@@ -701,7 +702,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
"""
config_class = XLNetConfig
pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
......
......@@ -33,9 +33,10 @@ from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
"transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-pytorch_model.bin",
}
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
"transfo-xl-wt103",
# See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
]
def build_tf_to_pytorch_map(model, config):
......@@ -453,7 +454,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
......
......@@ -110,6 +110,9 @@ class ModuleUtilsMixin:
@property
def device(self) -> device:
"""
Get torch.device from module, assuming that the whole module has one device.
"""
try:
return next(self.parameters()).device
except StopIteration:
......@@ -125,6 +128,9 @@ class ModuleUtilsMixin:
@property
def dtype(self) -> dtype:
"""
Get torch.dtype from module, assuming that the whole module has one dtype.
"""
try:
return next(self.parameters()).dtype
except StopIteration:
......@@ -249,7 +255,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
......@@ -259,7 +264,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
@property
......@@ -587,9 +591,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Load model
if pretrained_model_name_or_path is not None:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Load from a TF 1.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
......@@ -622,8 +624,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
use_cdn=use_cdn,
)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
......@@ -632,20 +634,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
resume_download=resume_download,
local_files_only=local_files_only,
)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
else:
msg = (
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to model weight files named one of {} but "
"couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
)
)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
......
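With the archive maps gone, a wrong identifier now fails with the single uniform message built in msg above instead of a dump of shortcut names; for example (illustrative, the identifier is deliberately invalid):

from transformers import BertModel

try:
    BertModel.from_pretrained("this-model-id-does-not-exist")
except EnvironmentError as err:
    # "Can't load weights for 'this-model-id-does-not-exist'. Make sure that: ..."
    print(err)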
......@@ -34,18 +34,19 @@ from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_l
logger = logging.getLogger(__name__)
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
"xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-pytorch_model.bin",
"xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-pytorch_model.bin",
"xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-pytorch_model.bin",
"xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-pytorch_model.bin",
"xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
"xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-pytorch_model.bin",
"xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-pytorch_model.bin",
"xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-pytorch_model.bin",
"xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-pytorch_model.bin",
"xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-pytorch_model.bin",
}
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlm-mlm-en-2048",
"xlm-mlm-ende-1024",
"xlm-mlm-enfr-1024",
"xlm-mlm-enro-1024",
"xlm-mlm-tlm-xnli15-1024",
"xlm-mlm-xnli15-1024",
"xlm-clm-enfr-1024",
"xlm-clm-ende-1024",
"xlm-mlm-17-1280",
"xlm-mlm-100-1280",
# See all XLM models at https://huggingface.co/models?filter=xlm
]
def create_sinusoidal_embeddings(n_pos, dim, out):
......@@ -207,7 +208,6 @@ class XLMPreTrainedModel(PreTrainedModel):
"""
config_class = XLMConfig
pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "transformer"
......
......@@ -31,14 +31,15 @@ from .modeling_roberta import (
logger = logging.getLogger(__name__)
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
"xlm-roberta-base": "https://cdn.huggingface.co/xlm-roberta-base-pytorch_model.bin",
"xlm-roberta-large": "https://cdn.huggingface.co/xlm-roberta-large-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll02-dutch": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll02-spanish": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll03-english": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll03-german": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
}
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlm-roberta-base",
"xlm-roberta-large",
"xlm-roberta-large-finetuned-conll02-dutch",
"xlm-roberta-large-finetuned-conll02-spanish",
"xlm-roberta-large-finetuned-conll03-english",
"xlm-roberta-large-finetuned-conll03-german",
# See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]
XLM_ROBERTA_START_DOCSTRING = r"""
......@@ -65,7 +66,6 @@ class XLMRobertaModel(RobertaModel):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -78,7 +78,6 @@ class XLMRobertaForMaskedLM(RobertaForMaskedLM):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -93,7 +92,6 @@ class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -108,7 +106,6 @@ class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
......@@ -123,4 +120,3 @@ class XLMRobertaForTokenClassification(RobertaForTokenClassification):
"""
config_class = XLMRobertaConfig
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
......@@ -32,10 +32,11 @@ from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogit
logger = logging.getLogger(__name__)
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
"xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-pytorch_model.bin",
"xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-pytorch_model.bin",
}
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlnet-base-cased",
"xlnet-large-cased",
# See all XLNet models at https://huggingface.co/models?filter=xlnet
]
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
......@@ -459,7 +460,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
"""
config_class = XLNetConfig
pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer"
......
......@@ -97,27 +97,24 @@ class AutoTokenizer:
when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method take care of returning the correct tokenizer class instance
The `from_pretrained()` method takes care of returning the correct tokenizer class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `electra`: ElectraTokenizer (Google ELECTRA model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: T5Tokenizer (T5 model)
- `distilbert`: DistilBertTokenizer (DistilBert model)
- `albert`: AlbertTokenizer (ALBERT model)
- `camembert`: CamembertTokenizer (CamemBERT model)
- `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- `longformer`: LongformerTokenizer (AllenAI Longformer model)
- `roberta`: RobertaTokenizer (RoBERTa model)
- `bert`: BertTokenizer (Bert model)
- `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- `xlnet`: XLNetTokenizer (XLNet model)
- `xlm`: XLMTokenizer (XLM model)
- `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- `electra`: ElectraTokenizer (Google ELECTRA model)
    This class cannot be instantiated using `__init__()` (throws an error).
"""
......@@ -133,24 +130,25 @@ class AutoTokenizer:
r""" Instantiate one of the tokenizer classes of the library
from a pre-trained model vocabulary.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `electra`: ElectraTokenizer (Google ELECTRA model)
The tokenizer class to instantiate is selected
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: T5Tokenizer (T5 model)
- `distilbert`: DistilBertTokenizer (DistilBert model)
- `albert`: AlbertTokenizer (ALBERT model)
- `camembert`: CamembertTokenizer (CamemBERT model)
- `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
- `longformer`: LongformerTokenizer (AllenAI Longformer model)
- `roberta`: RobertaTokenizer (RoBERTa model)
- `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
- `bert`: BertTokenizer (Bert model)
- `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- `xlnet`: XLNetTokenizer (XLNet model)
- `xlm`: XLMTokenizer (XLM model)
- `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- `electra`: ElectraTokenizer (Google ELECTRA model)
Params:
pretrained_model_name_or_path: either:
......
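As the docstring above describes, the tokenizer class is picked from the config's model_type or, failing that, from patterns in the identifier; a short sketch (requires network access to huggingface.co, and the Japanese example additionally needs the MeCab dependency installed):

from transformers import AutoTokenizer

# "gpt2" -> GPT2Tokenizer via the `gpt2` pattern / config model_type.
tok = AutoTokenizer.from_pretrained("gpt2")

# Namespaced ids work the same way: the `bert-base-japanese` pattern in the
# string selects BertJapaneseTokenizer.
ja_tok = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")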
......@@ -47,9 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
"TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
"wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
}
}
......@@ -69,9 +69,9 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"bert-base-finnish-cased-v1": 512,
"bert-base-finnish-uncased-v1": 512,
"bert-base-dutch-cased": 512,
"TurkuNLP/bert-base-finnish-cased-v1": 512,
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
"wietsedv/bert-base-dutch-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
......@@ -90,9 +90,9 @@ PRETRAINED_INIT_CONFIGURATION = {
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"bert-base-finnish-cased-v1": {"do_lower_case": False},
"bert-base-finnish-uncased-v1": {"do_lower_case": True},
"bert-base-dutch-cased": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}
......
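Because the Finnish and Dutch BERT entries are now keyed by their full namespaced identifiers, they load exactly like any other hub model; for example:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
# Init defaults such as do_lower_case=False are picked up from
# PRETRAINED_INIT_CONFIGURATION above, now keyed by the namespaced id.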
......@@ -30,37 +30,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
"cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
"cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-japanese": 512,
"bert-base-japanese-whole-word-masking": 512,
"bert-base-japanese-char": 512,
"bert-base-japanese-char-whole-word-masking": 512,
"cl-tohoku/bert-base-japanese": 512,
"cl-tohoku/bert-base-japanese-whole-word-masking": 512,
"cl-tohoku/bert-base-japanese-char": 512,
"cl-tohoku/bert-base-japanese-char-whole-word-masking": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-japanese": {
"cl-tohoku/bert-base-japanese": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "wordpiece",
},
"bert-base-japanese-whole-word-masking": {
"cl-tohoku/bert-base-japanese-whole-word-masking": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "wordpiece",
},
"bert-base-japanese-char": {
"cl-tohoku/bert-base-japanese-char": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "character",
},
"bert-base-japanese-char-whole-word-masking": {
"cl-tohoku/bert-base-japanese-char-whole-word-masking": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "character",
......
......@@ -942,13 +942,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
"Calling {}.from_pretrained() with the path to a single file or url is not supported."
"Use a model identifier or the path to a directory instead.".format(cls.__name__)
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not supported."
"Use a model identifier or the path to a directory instead."
)
logger.warning(
"Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
cls.__name__
)
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated"
)
file_id = list(cls.vocab_files_names.keys())[0]
vocab_files[file_id] = pretrained_model_name_or_path
......
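In short, passing a single vocabulary file to from_pretrained() is deprecated (and rejected outright for tokenizers with more than one vocab file); the supported forms are a model identifier or a directory, e.g.:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")  # model identifier
tok.save_pretrained("./my-bert")                          # writes vocab/config files
tok = BertTokenizer.from_pretrained("./my-bert")          # directory path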