"vscode:/vscode.git/clone" did not exist on "aac35514075290a46419b9bd969e6f94fef9d43b"
Unverified Commit d4c2cb40, authored by Julien Chaumond, committed by GitHub

Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
parent 47a551d1
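In practice this commit replaces the hard-coded *_PRETRAINED_MODEL_ARCHIVE_MAP dicts (shortcut name -> weight URL) with plain *_PRETRAINED_MODEL_ARCHIVE_LIST lists of model identifiers, and from_pretrained now resolves any identifier against huggingface.co instead of consulting those maps. A minimal sketch of what that means for callers, not part of the diff itself; the fine-tuned identifier in the second comment is hypothetical:

from transformers import TFGPT2LMHeadModel

# "distilgpt2" is one of the identifiers kept (for reference only) in the new
# TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST further down in this diff.
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")

# Any other hub identifier, e.g. "someuser/my-finetuned-gpt2" (hypothetical),
# now resolves the same way -- there is no privileged shortcut list anymore.
print(model.config.n_layer)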
@@ -13,14 +13,15 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "google/electra-small-generator": "https://cdn.huggingface.co/google/electra-small-generator/tf_model.h5",
-    "google/electra-base-generator": "https://cdn.huggingface.co/google/electra-base-generator/tf_model.h5",
-    "google/electra-large-generator": "https://cdn.huggingface.co/google/electra-large-generator/tf_model.h5",
-    "google/electra-small-discriminator": "https://cdn.huggingface.co/google/electra-small-discriminator/tf_model.h5",
-    "google/electra-base-discriminator": "https://cdn.huggingface.co/google/electra-base-discriminator/tf_model.h5",
-    "google/electra-large-discriminator": "https://cdn.huggingface.co/google/electra-large-discriminator/tf_model.h5",
-}
+TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "google/electra-small-generator",
+    "google/electra-base-generator",
+    "google/electra-large-generator",
+    "google/electra-small-discriminator",
+    "google/electra-base-discriminator",
+    "google/electra-large-discriminator",
+    # See all ELECTRA models at https://huggingface.co/models?filter=electra
+]
 class TFElectraEmbeddings(tf.keras.layers.Layer):
@@ -160,7 +161,6 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
 class TFElectraPreTrainedModel(TFBertPreTrainedModel):
     config_class = ElectraConfig
-    pretrained_model_archive_map = TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "electra"
     def get_extended_attention_mask(self, attention_mask, input_shape):
...
@@ -35,7 +35,9 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {}
+TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    # See all Flaubert models at https://huggingface.co/models?filter=flaubert
+]
 FLAUBERT_START_DOCSTRING = r"""
@@ -104,7 +106,6 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
 )
 class TFFlaubertModel(TFXLMModel):
     config_class = FlaubertConfig
-    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -309,7 +310,6 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
 )
 class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
     config_class = FlaubertConfig
-    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -323,7 +323,6 @@ class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
 )
 class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
     config_class = FlaubertConfig
-    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -37,13 +37,14 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "gpt2": "https://cdn.huggingface.co/gpt2-tf_model.h5",
-    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-tf_model.h5",
-    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-tf_model.h5",
-    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-tf_model.h5",
-    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-tf_model.h5",
-}
+TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "gpt2",
+    "gpt2-medium",
+    "gpt2-large",
+    "gpt2-xl",
+    "distilgpt2",
+    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
+]
 def gelu(x):
@@ -389,7 +390,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
     """
     config_class = GPT2Config
-    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
...
@@ -36,7 +36,10 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://cdn.huggingface.co/openai-gpt-tf_model.h5"}
+TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "openai-gpt",
+    # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
+]
 def gelu(x):
@@ -349,7 +352,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
     """
     config_class = OpenAIGPTConfig
-    pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
...
@@ -28,12 +28,13 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 logger = logging.getLogger(__name__)
-TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "roberta-base": "https://cdn.huggingface.co/roberta-base-tf_model.h5",
-    "roberta-large": "https://cdn.huggingface.co/roberta-large-tf_model.h5",
-    "roberta-large-mnli": "https://cdn.huggingface.co/roberta-large-mnli-tf_model.h5",
-    "distilroberta-base": "https://cdn.huggingface.co/distilroberta-base-tf_model.h5",
-}
+TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "roberta-base",
+    "roberta-large",
+    "roberta-large-mnli",
+    "distilroberta-base",
+    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
+]
 class TFRobertaEmbeddings(TFBertEmbeddings):
@@ -100,7 +101,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """
    config_class = RobertaConfig
-    pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "roberta"
...
@@ -30,13 +30,14 @@ from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
 logger = logging.getLogger(__name__)
-TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "t5-small": "https://cdn.huggingface.co/t5-small-tf_model.h5",
-    "t5-base": "https://cdn.huggingface.co/t5-base-tf_model.h5",
-    "t5-large": "https://cdn.huggingface.co/t5-large-tf_model.h5",
-    "t5-3b": "https://cdn.huggingface.co/t5-3b-tf_model.h5",
-    "t5-11b": "https://cdn.huggingface.co/t5-11b-tf_model.h5",
-}
+TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "t5-small",
+    "t5-base",
+    "t5-large",
+    "t5-3b",
+    "t5-11b",
+    # See all T5 models at https://huggingface.co/models?filter=t5
+]
 ####################################################
 # TF 2.0 Models are constructed using Keras imperative API by sub-classing
@@ -720,7 +721,6 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
     """
     config_class = T5Config
-    pretrained_model_archive_map = TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
     @property
...
@@ -30,9 +30,10 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-tf_model.h5",
-}
+TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "transfo-xl-wt103",
+    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
+]
 class TFPositionalEmbedding(tf.keras.layers.Layer):
@@ -630,7 +631,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
     """
     config_class = TransfoXLConfig
-    pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
...
@@ -112,7 +112,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         Class attributes (overridden by derived classes):
             - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
-            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
             - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
                 - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
@@ -122,7 +121,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
     """
     config_class = None
-    pretrained_model_archive_map = {}
     base_model_prefix = ""
     @property
@@ -338,9 +336,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # Load model
         if pretrained_model_name_or_path is not None:
-            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
-            elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isdir(pretrained_model_name_or_path):
                 if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                     # Load from a TF 2.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -364,8 +360,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     use_cdn=use_cdn,
                 )
-            # redirect to the cache, if necessary
             try:
+                # Load from URL or cache if already cached
                 resolved_archive_file = cached_path(
                     archive_file,
                     cache_dir=cache_dir,
@@ -373,20 +369,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     resume_download=resume_download,
                     proxies=proxies,
                 )
-            except EnvironmentError as e:
-                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                    logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
-                else:
-                    logger.error(
-                        "Model name '{}' was not found in model name list ({}). "
-                        "We assumed '{}' was a path or url but couldn't find any file "
-                        "associated to this path or url.".format(
-                            pretrained_model_name_or_path,
-                            ", ".join(cls.pretrained_model_archive_map.keys()),
-                            archive_file,
-                        )
-                    )
-                raise e
+                if resolved_archive_file is None:
+                    raise EnvironmentError
+            except EnvironmentError:
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
             if resolved_archive_file == archive_file:
                 logger.info("loading weights file {}".format(archive_file))
             else:
...
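A hedged usage sketch of the new TF loading path above: weights are found either in a local directory or via a model identifier, and an unknown identifier now surfaces the generic error message built in the except EnvironmentError branch. The bad identifier below is hypothetical:

from transformers import TFRobertaModel

# By identifier: resolved against huggingface.co and cached locally.
model = TFRobertaModel.from_pretrained("roberta-base")

# By local directory: save_pretrained() writes config.json + tf_model.h5,
# which the os.path.isdir() branch above picks up.
model.save_pretrained("./roberta-local")
model = TFRobertaModel.from_pretrained("./roberta-local")

# Unknown identifier (hypothetical) -> EnvironmentError with the new message.
try:
    TFRobertaModel.from_pretrained("not-a-real/model-id")
except EnvironmentError as err:
    print(err)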
@@ -31,18 +31,19 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-tf_model.h5",
-    "xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-tf_model.h5",
-    "xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-tf_model.h5",
-    "xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-tf_model.h5",
-    "xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-tf_model.h5",
-    "xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-tf_model.h5",
-    "xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-tf_model.h5",
-    "xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-tf_model.h5",
-    "xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-tf_model.h5",
-    "xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-tf_model.h5",
-}
+TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "xlm-mlm-en-2048",
+    "xlm-mlm-ende-1024",
+    "xlm-mlm-enfr-1024",
+    "xlm-mlm-enro-1024",
+    "xlm-mlm-tlm-xnli15-1024",
+    "xlm-mlm-xnli15-1024",
+    "xlm-clm-enfr-1024",
+    "xlm-clm-ende-1024",
+    "xlm-mlm-17-1280",
+    "xlm-mlm-100-1280",
+    # See all XLM models at https://huggingface.co/models?filter=xlm
+]
 def create_sinusoidal_embeddings(n_pos, dim, out):
@@ -470,7 +471,6 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLMConfig
-    pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
     @property
...
@@ -30,7 +30,9 @@ from .modeling_tf_roberta import (
 logger = logging.getLogger(__name__)
-TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {}
+TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
+]
 XLM_ROBERTA_START_DOCSTRING = r"""
@@ -72,7 +74,6 @@ class TFXLMRobertaModel(TFRobertaModel):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -85,7 +86,6 @@ class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -100,7 +100,6 @@ class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -115,4 +114,3 @@ class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -37,10 +37,11 @@ from .tokenization_utils import BatchEncoding
 logger = logging.getLogger(__name__)
-TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-tf_model.h5",
-    "xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-tf_model.h5",
-}
+TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "xlnet-base-cased",
+    "xlnet-large-cased",
+    # See all XLNet models at https://huggingface.co/models?filter=xlnet
+]
 def gelu(x):
@@ -701,7 +702,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLNetConfig
-    pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
...
@@ -33,9 +33,10 @@ from .modeling_utils import PreTrainedModel
 logger = logging.getLogger(__name__)
-TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-pytorch_model.bin",
-}
+TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "transfo-xl-wt103",
+    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
+]
 def build_tf_to_pytorch_map(model, config):
@@ -453,7 +454,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     """
     config_class = TransfoXLConfig
-    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"
...
@@ -110,6 +110,9 @@ class ModuleUtilsMixin:
     @property
     def device(self) -> device:
+        """
+        Get torch.device from module, assuming that the whole module has one device.
+        """
         try:
             return next(self.parameters()).device
         except StopIteration:
@@ -125,6 +128,9 @@ class ModuleUtilsMixin:
     @property
     def dtype(self) -> dtype:
+        """
+        Get torch.dtype from module, assuming that the whole module has one dtype.
+        """
         try:
             return next(self.parameters()).dtype
         except StopIteration:
@@ -249,7 +255,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         Class attributes (overridden by derived classes):
             - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
-            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
             - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
                 - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
@@ -259,7 +264,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
     """
     config_class = None
-    pretrained_model_archive_map = {}
     base_model_prefix = ""
     @property
@@ -587,9 +591,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # Load model
         if pretrained_model_name_or_path is not None:
-            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
-            elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isdir(pretrained_model_name_or_path):
                 if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                     # Load from a TF 1.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
@@ -622,8 +624,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     use_cdn=use_cdn,
                 )
-            # redirect to the cache, if necessary
             try:
+                # Load from URL or cache if already cached
                 resolved_archive_file = cached_path(
                     archive_file,
                     cache_dir=cache_dir,
@@ -632,19 +634,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     resume_download=resume_download,
                     local_files_only=local_files_only,
                 )
+                if resolved_archive_file is None:
+                    raise EnvironmentError
             except EnvironmentError:
-                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
-                else:
-                    msg = (
-                        "Model name '{}' was not found in model name list ({}). "
-                        "We assumed '{}' was a path or url to model weight files named one of {} but "
-                        "couldn't find any such file at this path or url.".format(
-                            pretrained_model_name_or_path,
-                            ", ".join(cls.pretrained_model_archive_map.keys()),
-                            archive_file,
-                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
-                        )
-                    )
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
+                )
                 raise EnvironmentError(msg)
...
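The two docstrings added above document the device and dtype convenience properties on ModuleUtilsMixin. A small sketch of how they behave; BertModel is just a convenient PreTrainedModel subclass and is not part of this diff:

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
print(model.device)  # e.g. device(type="cpu") -- assumes all parameters live on one device
print(model.dtype)   # e.g. torch.float32     -- assumes all parameters share one dtype

if torch.cuda.is_available():
    model.to("cuda:0")
    print(model.device)  # device(type="cuda", index=0)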
@@ -34,18 +34,19 @@ from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_l
 logger = logging.getLogger(__name__)
-XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-pytorch_model.bin",
-    "xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-pytorch_model.bin",
-    "xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-pytorch_model.bin",
-    "xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-pytorch_model.bin",
-    "xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
-    "xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-pytorch_model.bin",
-    "xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-pytorch_model.bin",
-    "xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-pytorch_model.bin",
-    "xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-pytorch_model.bin",
-    "xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-pytorch_model.bin",
-}
+XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "xlm-mlm-en-2048",
+    "xlm-mlm-ende-1024",
+    "xlm-mlm-enfr-1024",
+    "xlm-mlm-enro-1024",
+    "xlm-mlm-tlm-xnli15-1024",
+    "xlm-mlm-xnli15-1024",
+    "xlm-clm-enfr-1024",
+    "xlm-clm-ende-1024",
+    "xlm-mlm-17-1280",
+    "xlm-mlm-100-1280",
+    # See all XLM models at https://huggingface.co/models?filter=xlm
+]
 def create_sinusoidal_embeddings(n_pos, dim, out):
@@ -207,7 +208,6 @@ class XLMPreTrainedModel(PreTrainedModel):
     """
     config_class = XLMConfig
-    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
     base_model_prefix = "transformer"
...
@@ -31,14 +31,15 @@ from .modeling_roberta import (
 logger = logging.getLogger(__name__)
-XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "xlm-roberta-base": "https://cdn.huggingface.co/xlm-roberta-base-pytorch_model.bin",
-    "xlm-roberta-large": "https://cdn.huggingface.co/xlm-roberta-large-pytorch_model.bin",
-    "xlm-roberta-large-finetuned-conll02-dutch": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
-    "xlm-roberta-large-finetuned-conll02-spanish": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
-    "xlm-roberta-large-finetuned-conll03-english": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
-    "xlm-roberta-large-finetuned-conll03-german": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
-}
+XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "xlm-roberta-base",
+    "xlm-roberta-large",
+    "xlm-roberta-large-finetuned-conll02-dutch",
+    "xlm-roberta-large-finetuned-conll02-spanish",
+    "xlm-roberta-large-finetuned-conll03-english",
+    "xlm-roberta-large-finetuned-conll03-german",
+    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
+]
 XLM_ROBERTA_START_DOCSTRING = r"""
@@ -65,7 +66,6 @@ class XLMRobertaModel(RobertaModel):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -78,7 +78,6 @@ class XLMRobertaForMaskedLM(RobertaForMaskedLM):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -93,7 +92,6 @@ class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -108,7 +106,6 @@ class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 @add_start_docstrings(
@@ -123,4 +120,3 @@ class XLMRobertaForTokenClassification(RobertaForTokenClassification):
     """
     config_class = XLMRobertaConfig
-    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -32,10 +32,11 @@ from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogit
 logger = logging.getLogger(__name__)
-XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    "xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-pytorch_model.bin",
-    "xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-pytorch_model.bin",
-}
+XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "xlnet-base-cased",
+    "xlnet-large-cased",
+    # See all XLNet models at https://huggingface.co/models?filter=xlnet
+]
 def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
@@ -459,7 +460,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
     """
     config_class = XLNetConfig
-    pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"
...
@@ -97,27 +97,24 @@ class AutoTokenizer:
         when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
         class method.
-        The `from_pretrained()` method take care of returning the correct tokenizer class instance
+        The `from_pretrained()` method takes care of returning the correct tokenizer class instance
         based on the `model_type` property of the config object, or when it's missing,
-        falling back to using pattern matching on the `pretrained_model_name_or_path` string.
-        The tokenizer class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `t5`: T5Tokenizer (T5 model)
-            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
-            - contains `albert`: AlbertTokenizer (ALBERT model)
-            - contains `camembert`: CamembertTokenizer (CamemBERT model)
-            - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
-            - contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
-            - contains `roberta`: RobertaTokenizer (RoBERTa model)
-            - contains `bert`: BertTokenizer (Bert model)
-            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
-            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
-            - contains `xlnet`: XLNetTokenizer (XLNet model)
-            - contains `xlm`: XLMTokenizer (XLM model)
-            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
-            - contains `electra`: ElectraTokenizer (Google ELECTRA model)
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
+            - `t5`: T5Tokenizer (T5 model)
+            - `distilbert`: DistilBertTokenizer (DistilBert model)
+            - `albert`: AlbertTokenizer (ALBERT model)
+            - `camembert`: CamembertTokenizer (CamemBERT model)
+            - `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
+            - `longformer`: LongformerTokenizer (AllenAI Longformer model)
+            - `roberta`: RobertaTokenizer (RoBERTa model)
+            - `bert`: BertTokenizer (Bert model)
+            - `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
+            - `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
+            - `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
+            - `xlnet`: XLNetTokenizer (XLNet model)
+            - `xlm`: XLMTokenizer (XLM model)
+            - `ctrl`: CTRLTokenizer (Salesforce CTRL model)
+            - `electra`: ElectraTokenizer (Google ELECTRA model)
         This class cannot be instantiated using `__init__()` (throw an error).
     """
@@ -133,24 +130,25 @@ class AutoTokenizer:
         r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.
-        The tokenizer class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `t5`: T5Tokenizer (T5 model)
-            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
-            - contains `albert`: AlbertTokenizer (ALBERT model)
-            - contains `camembert`: CamembertTokenizer (CamemBERT model)
-            - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
-            - contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
-            - contains `roberta`: RobertaTokenizer (RoBERTa model)
-            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
-            - contains `bert`: BertTokenizer (Bert model)
-            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
-            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
-            - contains `xlnet`: XLNetTokenizer (XLNet model)
-            - contains `xlm`: XLMTokenizer (XLM model)
-            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
-            - contains `electra`: ElectraTokenizer (Google ELECTRA model)
+        The tokenizer class to instantiate is selected
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
+            - `t5`: T5Tokenizer (T5 model)
+            - `distilbert`: DistilBertTokenizer (DistilBert model)
+            - `albert`: AlbertTokenizer (ALBERT model)
+            - `camembert`: CamembertTokenizer (CamemBERT model)
+            - `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
+            - `longformer`: LongformerTokenizer (AllenAI Longformer model)
+            - `roberta`: RobertaTokenizer (RoBERTa model)
+            - `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
+            - `bert`: BertTokenizer (Bert model)
+            - `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
+            - `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
+            - `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
+            - `xlnet`: XLNetTokenizer (XLNet model)
+            - `xlm`: XLMTokenizer (XLM model)
+            - `ctrl`: CTRLTokenizer (Salesforce CTRL model)
+            - `electra`: ElectraTokenizer (Google ELECTRA model)
         Params:
             pretrained_model_name_or_path: either:
...
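A short usage sketch matching the rewritten docstring: resolution prefers the config's model_type and only then falls back to the name patterns listed above (not part of this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # -> RobertaTokenizer via the config's model_type
print(type(tokenizer).__name__)
print(tokenizer.encode("Hello world"))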
@@ -47,9 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
         "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
         "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
         "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
-        "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
-        "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
-        "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
+        "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
+        "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
+        "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
     }
 }
@@ -69,9 +69,9 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     "bert-base-cased-finetuned-mrpc": 512,
     "bert-base-german-dbmdz-cased": 512,
     "bert-base-german-dbmdz-uncased": 512,
-    "bert-base-finnish-cased-v1": 512,
-    "bert-base-finnish-uncased-v1": 512,
-    "bert-base-dutch-cased": 512,
+    "TurkuNLP/bert-base-finnish-cased-v1": 512,
+    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
+    "wietsedv/bert-base-dutch-cased": 512,
 }
 PRETRAINED_INIT_CONFIGURATION = {
@@ -90,9 +90,9 @@ PRETRAINED_INIT_CONFIGURATION = {
     "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
     "bert-base-german-dbmdz-cased": {"do_lower_case": False},
     "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
-    "bert-base-finnish-cased-v1": {"do_lower_case": False},
-    "bert-base-finnish-uncased-v1": {"do_lower_case": True},
-    "bert-base-dutch-cased": {"do_lower_case": False},
+    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
+    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
+    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
 }
...
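Following the change above, the Finnish and Dutch BERT checkpoints are addressed by their namespaced organization/model identifiers rather than bare shortcut names. A brief sketch (the sample sentence is arbitrary):

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
print(tok.tokenize("Hyvää päivää"))  # cased Finnish WordPiece, per do_lower_case=False above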
@@ -30,37 +30,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
-        "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
-        "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
-        "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
+        "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
+        "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
+        "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
+        "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
     }
 }
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    "bert-base-japanese": 512,
-    "bert-base-japanese-whole-word-masking": 512,
-    "bert-base-japanese-char": 512,
-    "bert-base-japanese-char-whole-word-masking": 512,
+    "cl-tohoku/bert-base-japanese": 512,
+    "cl-tohoku/bert-base-japanese-whole-word-masking": 512,
+    "cl-tohoku/bert-base-japanese-char": 512,
+    "cl-tohoku/bert-base-japanese-char-whole-word-masking": 512,
 }
 PRETRAINED_INIT_CONFIGURATION = {
-    "bert-base-japanese": {
+    "cl-tohoku/bert-base-japanese": {
         "do_lower_case": False,
         "word_tokenizer_type": "mecab",
         "subword_tokenizer_type": "wordpiece",
     },
-    "bert-base-japanese-whole-word-masking": {
+    "cl-tohoku/bert-base-japanese-whole-word-masking": {
         "do_lower_case": False,
         "word_tokenizer_type": "mecab",
         "subword_tokenizer_type": "wordpiece",
     },
-    "bert-base-japanese-char": {
+    "cl-tohoku/bert-base-japanese-char": {
         "do_lower_case": False,
         "word_tokenizer_type": "mecab",
         "subword_tokenizer_type": "character",
     },
-    "bert-base-japanese-char-whole-word-masking": {
+    "cl-tohoku/bert-base-japanese-char-whole-word-masking": {
         "do_lower_case": False,
         "word_tokenizer_type": "mecab",
         "subword_tokenizer_type": "character",
...
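Likewise, the Japanese BERT tokenizers now live under the cl-tohoku namespace. A hedged sketch; these configurations use word_tokenizer_type="mecab", so a local MeCab installation is assumed:

from transformers import BertJapaneseTokenizer

tok = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
print(tok.tokenize("こんにちは、世界"))  # MeCab word split followed by WordPiece subwords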
@@ -942,13 +942,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             if len(cls.vocab_files_names) > 1:
                 raise ValueError(
-                    "Calling {}.from_pretrained() with the path to a single file or url is not supported."
-                    "Use a model identifier or the path to a directory instead.".format(cls.__name__)
+                    f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not supported."
+                    "Use a model identifier or the path to a directory instead."
                 )
             logger.warning(
-                "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
-                    cls.__name__
-                )
+                f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated"
             )
             file_id = list(cls.vocab_files_names.keys())[0]
             vocab_files[file_id] = pretrained_model_name_or_path
...