"src/py/vscode:/vscode.git/clone" did not exist on "48f9e6ad560c3da287595af5bf347c03dd33d07b"
Unverified Commit d4c2cb40 authored by Julien Chaumond, committed by GitHub

Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
parent 47a551d1
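
Note (not part of the commit): with the archive maps removed, a pretrained identifier is no longer looked up in a hard-coded shortcut-name-to-URL dictionary; it is resolved directly against the Hugging Face bucket. A minimal sketch of that resolution, using the `hf_bucket_url` and `cached_path` helpers the diff below already relies on (the model id is only an example):

```python
# Sketch: resolve a model id to its hosted config file without an archive map.
from transformers.file_utils import CONFIG_NAME, cached_path, hf_bucket_url

model_id = "flaubert/flaubert_base_cased"  # new-style namespaced identifier

config_url = hf_bucket_url(model_id, filename=CONFIG_NAME)  # build the bucket URL for config.json
local_config = cached_path(config_url)  # download (or reuse the cache) and return a local path
print(local_config)
```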
......@@ -36,5 +36,4 @@ class CamembertConfig(RobertaConfig):
superclass for the appropriate documentation alongside usage examples.
"""
pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "camembert"
......@@ -27,7 +27,7 @@ CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf
class CTRLConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of an :class:`~transformers.CTRLModel`.
This is the configuration class to store the configuration of a :class:`~transformers.CTRLModel`.
It is used to instantiate an CTRL model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the `ctrl <https://huggingface.co/ctrl>`__ architecture from SalesForce.
......@@ -76,13 +76,8 @@ class CTRLConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "ctrl"
def __init__(
......
......@@ -90,12 +90,7 @@ class DistilBertConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "distilbert"
def __init__(
......
......@@ -89,12 +89,7 @@ class ElectraConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "electra"
def __init__(
......
......@@ -23,10 +23,10 @@ from .configuration_xlm import XLMConfig
logger = logging.getLogger(__name__)
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/config.json",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/config.json",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/config.json",
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
"flaubert/flaubert_base_uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/config.json",
"flaubert/flaubert_base_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/config.json",
"flaubert/flaubert_large_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/config.json",
}
......@@ -142,7 +142,6 @@ class FlaubertConfig(XLMConfig):
text in a given language.
"""
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "flaubert"
def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
......
......@@ -110,13 +110,8 @@ class GPT2Config(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "gpt2"
def __init__(
......
......@@ -33,7 +33,7 @@ LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class LongformerConfig(RobertaConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.LongformerModel`.
This is the configuration class to store the configuration of a :class:`~transformers.LongformerModel`.
It is used to instantiate an Longformer model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the RoBERTa `roberta-base <https://huggingface.co/roberta-base>`__ architecture with a sequence length 4,096.
......@@ -59,12 +59,7 @@ class LongformerConfig(RobertaConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "longformer"
def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
......
......@@ -18,10 +18,9 @@ from .configuration_bart import BartConfig
PRETRAINED_CONFIG_ARCHIVE_MAP = {
"marian-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
"Helsinki-NLP/opus-mt-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
}
class MarianConfig(BartConfig):
model_type = "marian"
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
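
Note (not part of the diff): as the Flaubert and Marian entries above show, flat shortcut names give way to namespaced model ids such as "flaubert/flaubert_base_cased" and "Helsinki-NLP/opus-mt-en-de". A hedged example of loading a config through the public API with one of the new ids:

```python
# Sketch: load a configuration by its namespaced model id; no archive map is consulted.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Helsinki-NLP/opus-mt-en-de")
```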
......@@ -30,7 +30,7 @@ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class OpenAIGPTConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of an :class:`~transformers.OpenAIGPTModel`.
This is the configuration class to store the configuration of a :class:`~transformers.OpenAIGPTModel`.
It is used to instantiate an GPT model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the `GPT <https://huggingface.co/openai-gpt>`__ architecture from OpenAI.
......@@ -108,13 +108,8 @@ class OpenAIGPTConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "openai-gpt"
def __init__(
......
......@@ -135,12 +135,7 @@ class ReformerConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "reformer"
def __init__(
......
......@@ -35,7 +35,7 @@ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class RobertaConfig(BertConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.RobertaModel`.
This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel`.
It is used to instantiate an RoBERTa model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
......@@ -59,12 +59,7 @@ class RobertaConfig(BertConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "roberta"
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
......
......@@ -59,7 +59,6 @@ class T5Config(PretrainedConfig):
initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0, used for initialization testing).
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "t5"
def __init__(
......
......@@ -30,7 +30,7 @@ TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class TransfoXLConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of an :class:`~transformers.TransfoXLModel`.
This is the configuration class to store the configuration of a :class:`~transformers.TransfoXLModel`.
It is used to instantiate a Transformer XL model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the `Transformer XL <https://huggingface.co/transfo-xl-wt103>`__ architecture.
......@@ -110,13 +110,8 @@ class TransfoXLConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "transfo-xl"
def __init__(
......
......@@ -20,7 +20,7 @@ import copy
import json
import logging
import os
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
......@@ -37,7 +37,6 @@ class PretrainedConfig(object):
It only affects the model's configuration.
Class attributes (overridden by derived classes):
- ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
- ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
Args:
......@@ -52,7 +51,6 @@ class PretrainedConfig(object):
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
Is the model used with Torchscript (for PyTorch models).
"""
pretrained_config_archive_map: Dict[str, str] = {}
model_type: str = ""
def __init__(self, **kwargs):
......@@ -204,9 +202,7 @@ class PretrainedConfig(object):
return cls.from_dict(config_dict, **kwargs)
@classmethod
def get_config_dict(
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
for instantiating a Config using `from_dict`.
......@@ -214,8 +210,6 @@ class PretrainedConfig(object):
Parameters:
pretrained_model_name_or_path (:obj:`string`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
A map of `shortcut names` to `url`. By default, will use the current class attribute.
Returns:
:obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
......@@ -227,12 +221,7 @@ class PretrainedConfig(object):
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
if pretrained_config_archive_map is None:
pretrained_config_archive_map = cls.pretrained_config_archive_map
if pretrained_model_name_or_path in pretrained_config_archive_map:
config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
......@@ -255,20 +244,10 @@ class PretrainedConfig(object):
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
if pretrained_model_name_or_path in pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file
)
else:
msg = (
"Can't load '{}'. Make sure that:\n\n"
"- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
"- or '{}' is the correct path to a directory containing a '{}' file\n\n".format(
pretrained_model_name_or_path,
pretrained_model_name_or_path,
pretrained_model_name_or_path,
CONFIG_NAME,
)
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
)
raise EnvironmentError(msg)
......
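
Note (not part of the diff): with the `pretrained_config_archive_map` argument and lookup removed, `get_config_dict` resolution reduces to: local directory, then local file or explicit URL, otherwise treat the string as a model id. A minimal sketch of that order follows; the final branch is elided in the visible hunk, so the `hf_bucket_url` fallback below is an assumption:

```python
# Sketch of the simplified config resolution (the last branch is assumed, not shown above).
import os

from transformers.file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


def resolve_config_file(pretrained_model_name_or_path: str) -> str:
    if os.path.isdir(pretrained_model_name_or_path):
        # A local directory: use the config.json inside it.
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        # An explicit local file or remote URL: use it as-is.
        config_file = pretrained_model_name_or_path
    else:
        # Anything else is treated as a model id hosted on huggingface.co (assumption).
        config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME)
    # cached_path returns a local path, downloading remote files into the cache.
    return cached_path(config_file)
```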
......@@ -152,13 +152,8 @@ class XLMConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm"
def __init__(
......
......@@ -39,5 +39,4 @@ class XLMRobertaConfig(RobertaConfig):
superclass for the appropriate documentation alongside usage examples.
"""
pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm-roberta"
......@@ -122,13 +122,8 @@ class XLNetConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlnet"
def __init__(
......
......@@ -32,6 +32,7 @@ from transformers import (
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -70,6 +71,7 @@ from transformers import (
XLMRobertaConfig,
XLNetConfig,
cached_path,
hf_bucket_url,
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
......@@ -82,261 +84,103 @@ if is_torch_available():
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMRobertaForMaskedLM,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
CamembertForMaskedLM,
CamembertForSequenceClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertWithLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
else:
(
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMRobertaForMaskedLM,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
CamembertForMaskedLM,
CamembertForSequenceClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertWithLMHeadModel,
DistilBertForMaskedLM,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
) = (
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = {
"bert": (
BertConfig,
TFBertForPreTraining,
BertForPreTraining,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"bert-large-uncased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-large-cased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-base-cased-finetuned-mrpc": (
BertConfig,
TFBertForSequenceClassification,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlnet": (
XLNetConfig,
TFXLNetLMHeadModel,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm": (
XLMConfig,
TFXLMWithLMHeadModel,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlm-roberta": (
XLMRobertaConfig,
TFXLMRobertaForMaskedLM,
XLMRobertaForMaskedLM,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"transfo-xl": (
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"openai-gpt": (
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (
RobertaConfig,
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"camembert": (
CamembertConfig,
TFCamembertForMaskedLM,
CamembertForMaskedLM,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"flaubert": (
FlaubertConfig,
TFFlaubertWithLMHeadModel,
FlaubertWithLMHeadModel,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert": (
DistilBertConfig,
TFDistilBertForMaskedLM,
DistilBertForMaskedLM,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (
CTRLConfig,
TFCTRLLMHeadModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"albert": (
AlbertConfig,
TFAlbertForPreTraining,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"t5": (
T5Config,
TFT5ForConditionalGeneration,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"electra": (
ElectraConfig,
TFElectraForPreTraining,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
}
......@@ -346,7 +190,7 @@ def convert_pt_checkpoint_to_tf(
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
# Initialise TF model
if config_file in aws_config_map:
......@@ -358,10 +202,9 @@ def convert_pt_checkpoint_to_tf(
tf_model = model_class(config)
# Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_model_maps:
pytorch_checkpoint_path = cached_path(
aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
)
if pytorch_checkpoint_path in aws_config_map.keys():
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
# Load PyTorch checkpoint in tf2 model:
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
......
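
Note (not part of the diff): the conversion script now builds the PyTorch checkpoint URL from the model id and the standard weights filename instead of consulting a model archive map. A small usage sketch mirroring the hunk above (the model id is only an example):

```python
# Sketch: fetch pretrained PyTorch weights for a hub-hosted model id without an archive map.
from transformers import WEIGHTS_NAME, cached_path, hf_bucket_url

model_id = "bert-base-cased-finetuned-mrpc"
pytorch_checkpoint_url = hf_bucket_url(model_id, filename=WEIGHTS_NAME)
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=False)
```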
......@@ -31,16 +31,17 @@ from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"albert-base-v1": "https://cdn.huggingface.co/albert-base-v1-pytorch_model.bin",
"albert-large-v1": "https://cdn.huggingface.co/albert-large-v1-pytorch_model.bin",
"albert-xlarge-v1": "https://cdn.huggingface.co/albert-xlarge-v1-pytorch_model.bin",
"albert-xxlarge-v1": "https://cdn.huggingface.co/albert-xxlarge-v1-pytorch_model.bin",
"albert-base-v2": "https://cdn.huggingface.co/albert-base-v2-pytorch_model.bin",
"albert-large-v2": "https://cdn.huggingface.co/albert-large-v2-pytorch_model.bin",
"albert-xlarge-v2": "https://cdn.huggingface.co/albert-xlarge-v2-pytorch_model.bin",
"albert-xxlarge-v2": "https://cdn.huggingface.co/albert-xxlarge-v2-pytorch_model.bin",
}
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"albert-base-v1",
"albert-large-v1",
"albert-xlarge-v1",
"albert-xxlarge-v1",
"albert-base-v2",
"albert-large-v2",
"albert-xlarge-v2",
"albert-xxlarge-v2",
# See all ALBERT models at https://huggingface.co/models?filter=albert
]
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
......@@ -365,7 +366,6 @@ class AlbertPreTrainedModel(PreTrainedModel):
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "albert"
def _init_weights(self, module):
......@@ -439,7 +439,6 @@ ALBERT_INPUTS_DOCSTRING = r"""
class AlbertModel(AlbertPreTrainedModel):
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
......
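
Note (not part of the diff): the per-model `ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP` of download URLs becomes a plain `ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST` of known model ids, and `pretrained_model_archive_map` disappears from the model classes. Loading is unchanged from the user's side, for example:

```python
# Sketch: from_pretrained resolves the model id itself; the archive list is informational only.
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert-base-v2")
```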