Unverified Commit d4c2cb40 authored by Julien Chaumond, committed by GitHub

Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
parent 47a551d1
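At the user level, this commit means checkpoints are referenced by their full model identifier on huggingface.co/models rather than by a short name resolved through a hard-coded archive map. A minimal sketch of the resulting usage, assuming the renamed TurkuNLP Finnish BERT checkpoint from the documentation table below:

# Minimal sketch, assuming a transformers version with AutoModel/AutoTokenizer.
# Shortcut names like "bert-base-finnish-cased-v1" give way to namespaced
# identifiers such as "TurkuNLP/bert-base-finnish-cased-v1".
from transformers import AutoModel, AutoTokenizer

model_id = "TurkuNLP/bert-base-finnish-cased-v1"  # full identifier from huggingface.co/models
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)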
@@ -63,33 +63,33 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | | | Trained on uncased German text by DBMDZ |
| | | (see `details on dbmdz repository <https://github.com/dbmdz/german-bert>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. |
| | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. |
| | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text. Text is tokenized into characters. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``TurkuNLP/bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``TurkuNLP/bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on uncased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``wietsedv/bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Dutch text. |
| | | (see `details on wietsedv repository <https://github.com/wietsedv/bertje/>`__). |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
@@ -259,32 +259,32 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | ``xlm-roberta-large`` | | ~355M parameters with 24-layers, 1027-hidden-state, 4096 feed-forward hidden-state, 16-heads, |
| | | | Trained on 2.5 TB of newly created clean CommonCrawl data in 100 languages |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| FlauBERT | ``flaubert-small-cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
+| FlauBERT | ``flaubert/flaubert_small_cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
| | | | FlauBERT small architecture |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
+| | ``flaubert/flaubert_base_uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
| | | | FlauBERT base architecture with uncased vocabulary |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
+| | ``flaubert/flaubert_base_cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
| | | | FlauBERT base architecture with cased vocabulary |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
+| | ``flaubert/flaubert_large_cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
| | | | FlauBERT large architecture |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Bart | ``bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
+| Bart | ``facebook/bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
+| | ``facebook/bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
+| | ``facebook/bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | | | bart-large base architecture finetuned on cnn summarization task |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
+| | ``facebook/mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
| | | | bart-large architecture pretrained on cc25 multilingual data , finetuned on WMT english romanian translation. |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| DialoGPT | ``DialoGPT-small`` | | 12-layer, 768-hidden, 12-heads, 124M parameters |
@@ -305,9 +305,9 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Longformer | ``longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
+| Longformer | ``allenai/longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
+| | ``allenai/longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
| | | | Starting from RoBERTa-large checkpoint, trained on documents of max length 4,096 |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
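The rows above only change the identifier column; the checkpoints themselves are unchanged. A hypothetical migration helper sketched from the renames shown in this table (the mapping and function below are illustrative, not part of the commit):

# Illustrative old-name -> new-identifier mapping taken from the doc table above;
# this helper is hypothetical and not shipped by the library.
RENAMED_CHECKPOINTS = {
    "bert-base-finnish-cased-v1": "TurkuNLP/bert-base-finnish-cased-v1",
    "bert-base-dutch-cased": "wietsedv/bert-base-dutch-cased",
    "bart-large-mnli": "facebook/bart-large-mnli",
    "longformer-base-4096": "allenai/longformer-base-4096",
}

def migrate_model_id(name: str) -> str:
    """Return the namespaced identifier if the short name was renamed, else the name itself."""
    return RENAMED_CHECKPOINTS.get(name, name)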
@@ -65,13 +65,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
-    ),
-    (),
-)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
@@ -389,7 +382,7 @@ def main():
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
...
@@ -34,26 +34,11 @@ from tqdm import tqdm, trange
from transformers import (
    WEIGHTS_NAME,
    AdamW,
-    AlbertConfig,
-    AlbertModel,
-    AlbertTokenizer,
-    BertConfig,
-    BertModel,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertModel,
-    DistilBertTokenizer,
+    AutoConfig,
+    AutoModel,
+    AutoTokenizer,
    MMBTConfig,
    MMBTForClassification,
-    RobertaConfig,
-    RobertaModel,
-    RobertaTokenizer,
-    XLMConfig,
-    XLMModel,
-    XLMTokenizer,
-    XLNetConfig,
-    XLNetModel,
-    XLNetTokenizer,
    get_linear_schedule_with_warmup,
)
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
@@ -67,23 +52,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
-    ),
-    (),
-)
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertModel, BertTokenizer),
-    "xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
-    "xlm": (XLMConfig, XLMModel, XLMTokenizer),
-    "roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
-    "albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
-}
def set_seed(args):
    random.seed(args.seed)
@@ -351,19 +319,12 @@ def main():
        required=True,
        help="The input data dir. Should contain the .jsonl files for MMIMDB.",
    )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
@@ -385,7 +346,7 @@ def main():
    )
    parser.add_argument(
        "--cache_dir",
-        default="",
+        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
@@ -526,18 +487,14 @@ def main():
    # Setup model
    labels = get_mmimdb_labels()
    num_labels = len(labels)
-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    transformer_config = config_class.from_pretrained(
-        args.config_name if args.config_name else args.model_name_or_path
-    )
-    tokenizer = tokenizer_class.from_pretrained(
+    transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
    )
-    transformer = model_class.from_pretrained(
-        args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
+    transformer = AutoModel.from_pretrained(
+        args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
    )
    img_encoder = ImageEncoder(args)
    config = MMBTConfig(transformer_config, num_labels=num_labels)
@@ -583,13 +540,12 @@ def main():
    # Load a trained model and vocabulary that you have fine-tuned
    model = MMBTForClassification(config, transformer, img_encoder)
    model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
-    tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model.to(args.device)
    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
...
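The pattern in the MM-IMDB hunks above replaces the per-architecture (config, model, tokenizer) tuples with the Auto* classes, so the architecture is inferred from the checkpoint's configuration. A compressed sketch of that setup, assuming argparse-style `args` fields mirroring the script's options:

# Sketch of the Auto*-based setup used above; `args` is an assumed argparse namespace
# with the same field names as the script, not new library API.
from transformers import AutoConfig, AutoModel, AutoTokenizer

def load_backbone(args):
    config = AutoConfig.from_pretrained(args.config_name or args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name or args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir,  # default is now None, so no `if args.cache_dir else None` dance
    )
    transformer = AutoModel.from_pretrained(args.model_name_or_path, config=config, cache_dir=args.cache_dir)
    return config, tokenizer, transformer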
@@ -31,14 +31,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Tenso
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
-from transformers import (
-    WEIGHTS_NAME,
-    AdamW,
-    BertConfig,
-    BertForMultipleChoice,
-    BertTokenizer,
-    get_linear_schedule_with_warmup,
-)
+from transformers import WEIGHTS_NAME, AdamW, AutoConfig, AutoTokenizer, get_linear_schedule_with_warmup
+from transformers.modeling_auto import AutoModelForMultipleChoice
try:
@@ -49,12 +43,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
-}
class SwagExample(object):
    """A single training/test example for the SWAG dataset."""
@@ -492,19 +480,12 @@ def main():
        required=True,
        help="SWAG csv for predictions. E.g., val.csv or test.csv",
    )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
@@ -536,9 +517,6 @@ def main():
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
    )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
-    )
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
@@ -652,13 +630,9 @@ def main():
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
-    )
-    model = model_class.from_pretrained(
+    config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,)
+    model = AutoModelForMultipleChoice.from_pretrained(
        args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
    )
@@ -694,8 +668,8 @@ def main():
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    # Load a trained model and vocabulary that you have fine-tuned
-    model = model_class.from_pretrained(args.output_dir)
-    tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+    model = AutoModelForMultipleChoice.from_pretrained(args.output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model.to(args.device)
    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
@@ -718,8 +692,8 @@ def main():
    for checkpoint in checkpoints:
        # Reload the model
        global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-        model = model_class.from_pretrained(checkpoint)
-        tokenizer = tokenizer_class.from_pretrained(checkpoint)
+        model = AutoModelForMultipleChoice.from_pretrained(checkpoint)
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model.to(args.device)
        # Evaluate
...
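In the SWAG script the multiple-choice head is now resolved through AutoModelForMultipleChoice (imported from transformers.modeling_auto in this version), and fine-tuned checkpoints written with save_pretrained are reloaded the same way. A small sketch of that save/reload round trip, assuming an output directory produced by the script:

# Sketch of reloading a fine-tuned checkpoint with the Auto classes, assuming
# `output_dir` was written by model.save_pretrained() / tokenizer.save_pretrained().
from transformers import AutoTokenizer
from transformers.modeling_auto import AutoModelForMultipleChoice  # top-level export may differ by version

output_dir = "./swag_output"  # hypothetical path
model = AutoModelForMultipleChoice.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)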
@@ -67,9 +67,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
-)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
@@ -505,7 +502,7 @@ def main():
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
...
@@ -19,7 +19,6 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
import logging
-from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.configuration_utils import PretrainedConfig
@@ -31,7 +30,6 @@ class MaskedBertConfig(PretrainedConfig):
    A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
    """
-    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "masked_bert"
    def __init__(
...
@@ -29,12 +29,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from emmental import MaskedBertConfig
from emmental.modules import MaskedLinear
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from transformers.modeling_bert import (
-    ACT2FN,
-    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-    BertLayerNorm,
-    load_tf_weights_in_bert,
-)
+from transformers.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
@@ -395,7 +390,6 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
    """
    config_class = MaskedBertConfig
-    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
...
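With pretrained_model_archive_map gone, a PreTrainedModel subclass only declares its config class, TF-loading hook, and base-model prefix; identifier resolution happens inside from_pretrained. A minimal sketch of such a subclass following the MaskedBert pattern above (class body abbreviated; weight-initialization and forward logic omitted):

# Minimal sketch following the MaskedBertPreTrainedModel pattern above;
# the archive-map class attribute is simply no longer declared.
from emmental import MaskedBertConfig  # as imported in the file above
from transformers.modeling_bert import load_tf_weights_in_bert
from transformers.modeling_utils import PreTrainedModel

class MyMaskedBertPreTrainedModel(PreTrainedModel):
    config_class = MaskedBertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    # _init_weights and the concrete model heads would follow here.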
@@ -53,8 +53,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
@@ -576,7 +574,7 @@ def main():
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
...
@@ -57,8 +57,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
    "masked_bert": (MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),
@@ -673,7 +671,7 @@ def main():
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
...
@@ -58,8 +58,6 @@ logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
def set_seed(args):
    random.seed(args.seed)
@@ -491,7 +489,7 @@ def main():
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
...
@@ -61,7 +61,6 @@ class BertAbsConfig(PretrainedConfig):
        the decoder.
    """
-    pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
    model_type = "bertabs"
    def __init__(
...
@@ -33,14 +33,13 @@ from transformers import BertConfig, BertModel, PreTrainedModel
MAX_SIZE = 5000
-BERTABS_FINETUNED_MODEL_MAP = {
-    "bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
-}
+BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
+    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
+]
class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
-    pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
    load_tf_weights = False
    base_model_prefix = "bert"
...
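The BertAbs change above illustrates the general shape of the migration: the old archive map was a dict from shortcut name to a hard-coded weights URL, while the new archive list simply enumerates hub model identifiers. Both shapes side by side, as a sketch using the values from this hunk:

# Old shape (removed by this commit): shortcut name -> hard-coded weights URL.
OLD_ARCHIVE_MAP = {
    "bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
}

# New shape: a flat list of identifiers resolvable on huggingface.co/models.
NEW_ARCHIVE_LIST = [
    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]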
@@ -258,7 +258,7 @@ TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recal
Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_xnli.py).
-[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).
+[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).
#### Fine-tuning on XNLI
@@ -273,7 +273,6 @@ on a single tesla V100 16GB. The data for XNLI can be downloaded with the follow
export XNLI_DIR=/path/to/XNLI
python run_xnli.py \
-  --model_type bert \
  --model_name_or_path bert-base-multilingual-cased \
  --language de \
  --train_language en \
...
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
+""" Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM).
    Adapted from `examples/text-classification/run_glue.py`"""
@@ -32,15 +32,9 @@ from tqdm import tqdm, trange
from transformers import (
    WEIGHTS_NAME,
    AdamW,
-    BertConfig,
-    BertForSequenceClassification,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertForSequenceClassification,
-    DistilBertTokenizer,
-    XLMConfig,
-    XLMForSequenceClassification,
-    XLMTokenizer,
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from transformers import glue_convert_examples_to_features as convert_examples_to_features
@@ -57,16 +51,6 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
-)
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
-    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
-}
def set_seed(args):
    random.seed(args.seed)
@@ -377,19 +361,12 @@ def main():
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--language",
@@ -421,7 +398,7 @@ def main():
    )
    parser.add_argument(
        "--cache_dir",
-        default="",
+        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
@@ -562,24 +539,23 @@ def main():
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
    )
-    tokenizer = tokenizer_class.from_pretrained(
+    args.model_type = config.model_type
+    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
    )
-    model = model_class.from_pretrained(
+    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
    )
    if args.local_rank == 0:
@@ -614,14 +590,13 @@ def main():
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    # Load a trained model and vocabulary that you have fine-tuned
-    model = model_class.from_pretrained(args.output_dir)
-    tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+    model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model.to(args.device)
    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
@@ -633,7 +608,7 @@ def main():
        global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
        prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
-        model = model_class.from_pretrained(checkpoint)
+        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        model.to(args.device)
        result = evaluate(args, model, tokenizer, prefix=prefix)
        result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
...
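One detail worth noting in the run_xnli hunks above: the --model_type flag is gone, and the script now derives the model type from the loaded configuration (config.model_type) wherever per-architecture behavior is still needed. A sketch of that pattern, using the checkpoint shown in the README command above:

# Sketch: derive the model type from the checkpoint's config instead of a CLI flag.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-multilingual-cased")
model_type = config.model_type  # "bert" for this checkpoint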
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """
+""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
import logging
...
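The library __init__ imports that follow rename every *_PRETRAINED_MODEL_ARCHIVE_MAP export to a *_PRETRAINED_MODEL_ARCHIVE_LIST of model identifiers and drop the aggregate ALL_PRETRAINED_MODEL_ARCHIVE_MAP / TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP exports. A sketch of how downstream code might consume one of the renamed constants (the loop is illustrative only):

# Sketch: the renamed constants are plain lists of hub identifiers.
from transformers import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

for model_id in BERT_PRETRAINED_MODEL_ARCHIVE_LIST:
    print(model_id)  # e.g. "bert-base-uncased", "bert-base-cased", ...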
@@ -159,7 +159,6 @@ if is_torch_available():
        AutoModelWithLMHead,
        AutoModelForTokenClassification,
        AutoModelForMultipleChoice,
-        ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
        MODEL_MAPPING,
        MODEL_FOR_PRETRAINING_MAPPING,
        MODEL_WITH_LM_HEAD_MAPPING,
@@ -180,7 +179,7 @@ if is_torch_available():
        BertForTokenClassification,
        BertForQuestionAnswering,
        load_tf_weights_in_bert,
-        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        BertLayer,
    )
    from .modeling_openai import (
@@ -189,7 +188,7 @@ if is_torch_available():
        OpenAIGPTLMHeadModel,
        OpenAIGPTDoubleHeadsModel,
        load_tf_weights_in_openai_gpt,
-        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_transfo_xl import (
        TransfoXLPreTrainedModel,
@@ -197,7 +196,7 @@ if is_torch_available():
        TransfoXLLMHeadModel,
        AdaptiveEmbedding,
        load_tf_weights_in_transfo_xl,
-        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_gpt2 import (
        GPT2PreTrainedModel,
@@ -205,9 +204,9 @@ if is_torch_available():
        GPT2LMHeadModel,
        GPT2DoubleHeadsModel,
        load_tf_weights_in_gpt2,
-        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
-    from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
+    from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
    from .modeling_xlnet import (
        XLNetPreTrainedModel,
        XLNetModel,
@@ -218,7 +217,7 @@ if is_torch_available():
        XLNetForQuestionAnsweringSimple,
        XLNetForQuestionAnswering,
        load_tf_weights_in_xlnet,
-        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_xlm import (
        XLMPreTrainedModel,
@@ -228,13 +227,13 @@ if is_torch_available():
        XLMForTokenClassification,
        XLMForQuestionAnswering,
        XLMForQuestionAnsweringSimple,
-        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_bart import (
        BartForSequenceClassification,
        BartModel,
        BartForConditionalGeneration,
-        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
+        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_marian import MarianMTModel
    from .tokenization_marian import MarianTokenizer
@@ -245,7 +244,7 @@ if is_torch_available():
        RobertaForMultipleChoice,
        RobertaForTokenClassification,
        RobertaForQuestionAnswering,
-        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_distilbert import (
        DistilBertPreTrainedModel,
@@ -254,7 +253,7 @@ if is_torch_available():
        DistilBertForSequenceClassification,
        DistilBertForQuestionAnswering,
        DistilBertForTokenClassification,
-        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_camembert import (
        CamembertForMaskedLM,
@@ -263,7 +262,7 @@ if is_torch_available():
        CamembertForMultipleChoice,
        CamembertForTokenClassification,
        CamembertForQuestionAnswering,
-        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_encoder_decoder import EncoderDecoderModel
    from .modeling_t5 import (
@@ -271,7 +270,7 @@ if is_torch_available():
        T5Model,
        T5ForConditionalGeneration,
        load_tf_weights_in_t5,
-        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+        T5_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_albert import (
        AlbertPreTrainedModel,
@@ -282,7 +281,7 @@ if is_torch_available():
        AlbertForQuestionAnswering,
        AlbertForTokenClassification,
        load_tf_weights_in_albert,
-        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_xlm_roberta import (
        XLMRobertaForMaskedLM,
@@ -290,7 +289,7 @@ if is_torch_available():
        XLMRobertaForMultipleChoice,
        XLMRobertaForSequenceClassification,
        XLMRobertaForTokenClassification,
-        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
@@ -300,7 +299,7 @@ if is_torch_available():
        FlaubertForSequenceClassification,
        FlaubertForQuestionAnswering,
        FlaubertForQuestionAnsweringSimple,
-        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_electra import (
@@ -311,7 +310,7 @@ if is_torch_available():
        ElectraForSequenceClassification,
        ElectraModel,
        load_tf_weights_in_electra,
-        ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_reformer import (
@@ -319,7 +318,7 @@ if is_torch_available():
        ReformerLayer,
        ReformerModel,
        ReformerModelWithLMHead,
-        REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_longformer import (
@@ -329,7 +328,7 @@ if is_torch_available():
        LongformerForMultipleChoice,
        LongformerForTokenClassification,
        LongformerForQuestionAnswering,
-        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    # Optimization
@@ -367,7 +366,6 @@ if is_tf_available():
        TFAutoModelForQuestionAnswering,
        TFAutoModelWithLMHead,
        TFAutoModelForTokenClassification,
-        TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_MODEL_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
@@ -388,7 +386,7 @@ if is_tf_available():
        TFBertForMultipleChoice,
        TFBertForTokenClassification,
        TFBertForQuestionAnswering,
-        TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_gpt2 import (
@@ -397,7 +395,7 @@ if is_tf_available():
        TFGPT2Model,
        TFGPT2LMHeadModel,
        TFGPT2DoubleHeadsModel,
-        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_openai import (
@@ -406,7 +404,7 @@ if is_tf_available():
        TFOpenAIGPTModel,
        TFOpenAIGPTLMHeadModel,
        TFOpenAIGPTDoubleHeadsModel,
-        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_transfo_xl import (
@@ -414,7 +412,7 @@ if is_tf_available():
        TFTransfoXLMainLayer,
        TFTransfoXLModel,
        TFTransfoXLLMHeadModel,
-        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFAdaptiveEmbedding,
    )
@@ -426,7 +424,7 @@ if is_tf_available():
        TFXLNetForSequenceClassification,
        TFXLNetForTokenClassification,
        TFXLNetForQuestionAnsweringSimple,
-        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_xlm import (
@@ -436,7 +434,7 @@ if is_tf_available():
        TFXLMWithLMHeadModel,
        TFXLMForSequenceClassification,
        TFXLMForQuestionAnsweringSimple,
-        TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_xlm_roberta import (
@@ -444,7 +442,7 @@ if is_tf_available():
        TFXLMRobertaModel,
        TFXLMRobertaForSequenceClassification,
        TFXLMRobertaForTokenClassification,
-        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_roberta import (
@@ -455,7 +453,7 @@ if is_tf_available():
        TFRobertaForSequenceClassification,
        TFRobertaForTokenClassification,
        TFRobertaForQuestionAnswering,
-        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_camembert import (
@@ -463,14 +461,14 @@ if is_tf_available():
        TFCamembertForMaskedLM,
        TFCamembertForSequenceClassification,
        TFCamembertForTokenClassification,
-        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_flaubert import (
        TFFlaubertModel,
        TFFlaubertWithLMHeadModel,
        TFFlaubertForSequenceClassification,
-        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_distilbert import (
@@ -481,14 +479,14 @@ if is_tf_available():
        TFDistilBertForSequenceClassification,
        TFDistilBertForTokenClassification,
        TFDistilBertForQuestionAnswering,
-        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_ctrl import (
        TFCTRLPreTrainedModel,
        TFCTRLModel,
        TFCTRLLMHeadModel,
-        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_albert import (
@@ -500,14 +498,14 @@ if is_tf_available():
        TFAlbertForMultipleChoice,
        TFAlbertForSequenceClassification,
        TFAlbertForQuestionAnswering,
-        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_t5 import (
        TFT5PreTrainedModel,
        TFT5Model,
        TFT5ForConditionalGeneration,
-        TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_tf_electra import (
@@ -516,7 +514,7 @@ if is_tf_available():
        TFElectraForPreTraining,
TFElectraForMaskedLM, TFElectraForMaskedLM,
TFElectraForTokenClassification, TFElectraForTokenClassification,
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
) )
# Optimization # Optimization
......
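Note on the rename above: the old ``*_PRETRAINED_MODEL_ARCHIVE_MAP`` constants were dictionaries mapping shortcut names to hard-coded S3 weight URLs, while the new ``*_PRETRAINED_MODEL_ARCHIVE_LIST`` constants only enumerate model identifiers and leave URL resolution to the identifier itself. A minimal sketch of the new shape (entries below are illustrative, not exhaustive):

```python
# Illustrative sketch only: after this change the archive constants are plain
# lists of model identifiers hosted on huggingface.co, instead of dicts that
# mapped shortcut names to weight URLs.
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]
```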
...@@ -32,7 +32,7 @@ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -32,7 +32,7 @@ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class AlbertConfig(PretrainedConfig): class AlbertConfig(PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of an :class:`~transformers.AlbertModel`. This is the configuration class to store the configuration of a :class:`~transformers.AlbertModel`.
It is used to instantiate an ALBERT model according to the specified arguments, defining the model It is used to instantiate an ALBERT model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the ALBERT `xxlarge <https://huggingface.co/albert-xxlarge-v2>`__ architecture. the ALBERT `xxlarge <https://huggingface.co/albert-xxlarge-v2>`__ architecture.
...@@ -97,13 +97,8 @@ class AlbertConfig(PretrainedConfig): ...@@ -97,13 +97,8 @@ class AlbertConfig(PretrainedConfig):
# Accessing the model configuration # Accessing the model configuration
configuration = model.config configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
""" """
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "albert" model_type = "albert"
def __init__( def __init__(
......
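Dropping the class-level ``pretrained_config_archive_map`` attribute does not change the public loading path: configurations are still fetched by model identifier. A quick sanity check, assuming network access to huggingface.co:

```python
from transformers import AlbertConfig

# Loading by identifier still works without the class-level archive map;
# the config location is now resolved from the model id alone.
config = AlbertConfig.from_pretrained("albert-xxlarge-v2")
print(config.model_type)  # "albert"
```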
...@@ -113,12 +113,12 @@ class AutoConfig: ...@@ -113,12 +113,12 @@ class AutoConfig:
) )
@classmethod @classmethod
def for_model(cls, model_type, *args, **kwargs): def for_model(cls, model_type: str, *args, **kwargs):
for pattern, config_class in CONFIG_MAPPING.items(): if model_type in CONFIG_MAPPING:
if pattern in model_type: config_class = CONFIG_MAPPING[model_type]
return config_class(*args, **kwargs) return config_class(*args, **kwargs)
raise ValueError( raise ValueError(
"Unrecognized model identifier in {}. Should contain one of {}".format( "Unrecognized model identifier: {}. Should contain one of {}".format(
model_type, ", ".join(CONFIG_MAPPING.keys()) model_type, ", ".join(CONFIG_MAPPING.keys())
) )
) )
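The new ``for_model`` implementation requires an exact ``model_type`` key instead of a substring match, so full model ids no longer resolve through this helper. A small sketch of the behavioral difference:

```python
from transformers import AutoConfig

# Exact model_type keys resolve to a default config, as before.
config = AutoConfig.for_model("bert")  # -> BertConfig with default values

# Identifiers that merely *contain* a type name no longer match
# (previously "bert-base-uncased" would have hit the "bert" pattern).
try:
    AutoConfig.for_model("bert-base-uncased")
except ValueError as err:
    print(err)
```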
...@@ -130,24 +130,24 @@ class AutoConfig: ...@@ -130,24 +130,24 @@ class AutoConfig:
The configuration class to instantiate is selected The configuration class to instantiate is selected
based on the `model_type` property of the config object, or when it's missing, based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string. falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- contains `t5`: :class:`~transformers.T5Config` (T5 model) - `t5`: :class:`~transformers.T5Config` (T5 model)
- contains `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model) - `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertConfig` (ALBERT model) - `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model) - `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model) - `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerConfig` (Longformer model) - `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model) - `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model) - `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model) - `bert`: :class:`~transformers.BertConfig` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model) - `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model) - `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
- contains `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model) - `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model)
- contains `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model) - `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMConfig` (XLM model) - `xlm`: :class:`~transformers.XLMConfig` (XLM model)
- contains `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model) - `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
- contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model) - `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
- contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model) - `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
Args: Args:
pretrained_model_name_or_path (:obj:`string`): pretrained_model_name_or_path (:obj:`string`):
...@@ -193,9 +193,7 @@ class AutoConfig: ...@@ -193,9 +193,7 @@ class AutoConfig:
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
)
if "model_type" in config_dict: if "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]] config_class = CONFIG_MAPPING[config_dict["model_type"]]
......
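With ``ALL_PRETRAINED_CONFIG_ARCHIVE_MAP`` unhooked, ``AutoConfig.from_pretrained`` downloads the config for the given identifier and selects the class from its ``model_type`` field, falling back to name pattern matching only for configs that lack it. For example (assumes network access):

```python
from transformers import AutoConfig, BertConfig

# The hosted config.json declares "model_type": "bert", so the BERT
# configuration class is chosen without consulting any archive map.
config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese")
assert isinstance(config, BertConfig)
```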
...@@ -23,11 +23,11 @@ from .configuration_utils import PretrainedConfig ...@@ -23,11 +23,11 @@ from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json", "facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json", "facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json", "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
} }
...@@ -36,7 +36,6 @@ class BartConfig(PretrainedConfig): ...@@ -36,7 +36,6 @@ class BartConfig(PretrainedConfig):
Configuration class for Bart. Parameters are renamed from the fairseq implementation Configuration class for Bart. Parameters are renamed from the fairseq implementation
""" """
model_type = "bart" model_type = "bart"
pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__( def __init__(
self, self,
......
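Since BART checkpoints are now keyed by their full ``facebook/...`` identifiers, downstream code should use the namespaced names when loading, e.g. (assumes network access):

```python
from transformers import BartConfig

# The un-prefixed "bart-large-cnn" shortcut is replaced by the fully
# namespaced model id on huggingface.co.
config = BartConfig.from_pretrained("facebook/bart-large-cnn")
print(config.model_type)  # "bart"
```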
...@@ -39,13 +39,14 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -39,13 +39,14 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
# See all BERT models at https://huggingface.co/models?filter=bert
} }
...@@ -102,12 +103,7 @@ class BertConfig(PretrainedConfig): ...@@ -102,12 +103,7 @@ class BertConfig(PretrainedConfig):
# Accessing the model configuration # Accessing the model configuration
configuration = model.config configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
""" """
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "bert" model_type = "bert"
def __init__( def __init__(
......
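The same namespacing applies to the community BERT checkpoints above (``cl-tohoku/``, ``TurkuNLP/``, ``wietsedv/``), and tokenizers resolve the new ids the same way. A hedged example, assuming the hosted files are unchanged and network access is available:

```python
from transformers import AutoTokenizer

# AutoTokenizer picks BertTokenizer from the repo's config and fetches the
# vocabulary under the namespaced identifier.
tokenizer = AutoTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
print(tokenizer.tokenize("Helsinki on Suomen pääkaupunki"))
```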