Unverified commit d4c2cb40, authored by Julien Chaumond, committed by GitHub

Kill model archive maps (#4636)

* Kill model archive maps

* Fixup

* Also kill model_archive_map for MaskedBertPreTrainedModel

* Unhook config_archive_map

* Tokenizers: align with model id changes

* make style && make quality

* Fix CI
parent 47a551d1
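In effect, this change drops the hard-coded shortcut-name → URL archive maps and resolves model identifiers, including namespaced ``organization/name`` ids, directly against huggingface.co. A minimal sketch of the resulting calling convention (not taken from the diff; the checkpoint id is one of those renamed in the tables below):

```python
from transformers import AutoModel, AutoTokenizer

# Community checkpoints are now addressed by their full namespaced id;
# the library resolves the files from the id, with no archive-map lookup.
tokenizer = AutoTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
model = AutoModel.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
```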
......@@ -63,33 +63,33 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | | | Trained on uncased German text by DBMDZ |
| | | (see `details on dbmdz repository <https://github.com/dbmdz/german-bert>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``cl-tohoku/bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. |
| | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``cl-tohoku/bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. |
| | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``cl-tohoku/bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text. Text is tokenized into characters. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``cl-tohoku/bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``TurkuNLP/bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``TurkuNLP/bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on uncased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | ``wietsedv/bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Dutch text. |
| | | (see `details on wietsedv repository <https://github.com/wietsedv/bertje/>`__). |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
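The Japanese checkpoints keep their MeCab requirement under the new ``cl-tohoku/`` prefix. A hedged loading sketch (assumes MeCab and its Python bindings are installed, as the table notes):

```python
from transformers import AutoModel, AutoTokenizer

# Text is tokenized with MeCab and WordPiece, so MeCab must be available locally.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")
```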
......@@ -259,32 +259,32 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | ``xlm-roberta-large`` | | ~355M parameters with 24-layers, 1024-hidden-state, 4096 feed-forward hidden-state, 16-heads, |
| | | | Trained on 2.5 TB of newly created clean CommonCrawl data in 100 languages |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| FlauBERT | ``flaubert-small-cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
| FlauBERT | ``flaubert/flaubert_small_cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
| | | | FlauBERT small architecture |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``flaubert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
| | ``flaubert/flaubert_base_uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
| | | | FlauBERT base architecture with uncased vocabulary |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``flaubert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
| | ``flaubert/flaubert_base_cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
| | | | FlauBERT base architecture with cased vocabulary |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``flaubert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
| | ``flaubert/flaubert_large_cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
| | | | FlauBERT large architecture |
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Bart | ``bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
| Bart | ``facebook/bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | ``facebook/bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | ``facebook/bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | | | bart-large base architecture finetuned on the CNN/Daily Mail summarization task |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
| | ``facebook/mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
| | | | bart-large architecture pretrained on cc25 multilingual data, finetuned on WMT English-Romanian translation. |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| DialoGPT | ``DialoGPT-small`` | | 12-layer, 768-hidden, 12-heads, 124M parameters |
......@@ -305,9 +305,9 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter machine translation models. Parameter counts vary depending on vocab size. |
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Longformer | ``longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| Longformer | ``allenai/longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
| | ``allenai/longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
| | | | Starting from RoBERTa-large checkpoint, trained on documents of max length 4,096 |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
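The renamed ids are drop-in replacements for the old shortcuts; for example, the former ``bart-large-cnn`` weights now load under the ``facebook/`` namespace. A sketch, assuming the standard Bart classes exported by this version of the library:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Same weights as the old "bart-large-cnn" shortcut, addressed by the new id.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
```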
......@@ -65,13 +65,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(
tuple(conf.pretrained_config_archive_map.keys())
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
),
(),
)
MODEL_CLASSES = {
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
......@@ -389,7 +382,7 @@ def main():
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--task_name",
......
......@@ -34,26 +34,11 @@ from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
AlbertConfig,
AlbertModel,
AlbertTokenizer,
BertConfig,
BertModel,
BertTokenizer,
DistilBertConfig,
DistilBertModel,
DistilBertTokenizer,
AutoConfig,
AutoModel,
AutoTokenizer,
MMBTConfig,
MMBTForClassification,
RobertaConfig,
RobertaModel,
RobertaTokenizer,
XLMConfig,
XLMModel,
XLMTokenizer,
XLNetConfig,
XLNetModel,
XLNetTokenizer,
get_linear_schedule_with_warmup,
)
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
......@@ -67,23 +52,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(
tuple(conf.pretrained_config_archive_map.keys())
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
),
(),
)
MODEL_CLASSES = {
"bert": (BertConfig, BertModel, BertTokenizer),
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
"xlm": (XLMConfig, XLMModel, XLMTokenizer),
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
}
def set_seed(args):
random.seed(args.seed)
......@@ -351,19 +319,12 @@ def main():
required=True,
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
......@@ -385,7 +346,7 @@ def main():
)
parser.add_argument(
"--cache_dir",
default="",
default=None,
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
......@@ -526,18 +487,14 @@ def main():
# Setup model
labels = get_mmimdb_labels()
num_labels = len(labels)
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
transformer_config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path
)
tokenizer = tokenizer_class.from_pretrained(
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)
transformer = model_class.from_pretrained(
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
transformer = AutoModel.from_pretrained(
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
)
img_encoder = ImageEncoder(args)
config = MMBTConfig(transformer_config, num_labels=num_labels)
......@@ -583,13 +540,12 @@ def main():
# Load a trained model and vocabulary that you have fine-tuned
model = MMBTForClassification(config, transformer, img_encoder)
model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
......
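The MMBT script now follows the generic Auto-class pattern instead of a hand-maintained ``MODEL_CLASSES`` table. A condensed sketch of that pattern (the model id is illustrative; any identifier from huggingface.co/models or a local path works):

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name_or_path = "bert-base-uncased"  # illustrative
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModel.from_pretrained(model_name_or_path, config=config)
```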
......@@ -31,14 +31,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Tenso
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
BertConfig,
BertForMultipleChoice,
BertTokenizer,
get_linear_schedule_with_warmup,
)
from transformers import WEIGHTS_NAME, AdamW, AutoConfig, AutoTokenizer, get_linear_schedule_with_warmup
from transformers.modeling_auto import AutoModelForMultipleChoice
try:
......@@ -49,12 +43,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
MODEL_CLASSES = {
"bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
}
class SwagExample(object):
"""A single training/test example for the SWAG dataset."""
......@@ -492,19 +480,12 @@ def main():
required=True,
help="SWAG csv for predictions. E.g., val.csv or test.csv",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
......@@ -536,9 +517,6 @@ def main():
parser.add_argument(
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
)
parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument(
......@@ -652,13 +630,9 @@ def main():
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
)
model = model_class.from_pretrained(
config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
model = AutoModelForMultipleChoice.from_pretrained(
args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
)
......@@ -694,8 +668,8 @@ def main():
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model = AutoModelForMultipleChoice.from_pretrained(args.output_dir)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
......@@ -718,8 +692,8 @@ def main():
for checkpoint in checkpoints:
# Reload the model
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint)
tokenizer = tokenizer_class.from_pretrained(checkpoint)
model = AutoModelForMultipleChoice.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.to(args.device)
# Evaluate
......
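Note that in this version ``AutoModelForMultipleChoice`` is imported from ``transformers.modeling_auto`` rather than the package root, as the diff above shows. A minimal sketch of the new loading path (model id illustrative):

```python
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_auto import AutoModelForMultipleChoice

config = AutoConfig.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased", config=config)
```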
......@@ -67,9 +67,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
)
MODEL_CLASSES = {
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
......@@ -505,7 +502,7 @@ def main():
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
......
......@@ -19,7 +19,6 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
import logging
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.configuration_utils import PretrainedConfig
......@@ -31,7 +30,6 @@ class MaskedBertConfig(PretrainedConfig):
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "masked_bert"
def __init__(
......
......@@ -29,12 +29,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from emmental import MaskedBertConfig
from emmental.modules import MaskedLinear
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
ACT2FN,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayerNorm,
load_tf_weights_in_bert,
)
from transformers.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
......@@ -395,7 +390,6 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
"""
config_class = MaskedBertConfig
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
......
......@@ -53,8 +53,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
MODEL_CLASSES = {
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
......@@ -576,7 +574,7 @@ def main():
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--task_name",
......
......@@ -57,8 +57,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
MODEL_CLASSES = {
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
"masked_bert": (MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),
......@@ -673,7 +671,7 @@ def main():
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
......
......@@ -58,8 +58,6 @@ logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
def set_seed(args):
random.seed(args.seed)
......@@ -491,7 +489,7 @@ def main():
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
......
......@@ -61,7 +61,6 @@ class BertAbsConfig(PretrainedConfig):
the decoder.
"""
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
model_type = "bertabs"
def __init__(
......
......@@ -33,14 +33,13 @@ from transformers import BertConfig, BertModel, PreTrainedModel
MAX_SIZE = 5000
BERTABS_FINETUNED_MODEL_MAP = {
"bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
}
BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
"remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]
class BertAbsPreTrainedModel(PreTrainedModel):
config_class = BertAbsConfig
pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
load_tf_weights = False
base_model_prefix = "bert"
......
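The same map-to-list conversion applies here: the dict mapping a shortcut name to an explicit weights URL is replaced by a list of canonical ids, and ``pretrained_model_archive_map`` is dropped from the model class. Side by side, as shown in the diff above:

```python
# Before: shortcut name -> explicit weights URL
BERTABS_FINETUNED_MODEL_MAP = {
    "bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
}

# After: a reference list of canonical model ids; file locations are derived from the id
BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]
```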
......@@ -258,7 +258,7 @@ TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recal
Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_xnli.py).
[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).
[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource languages such as English and low-resource languages such as Swahili).
#### Fine-tuning on XNLI
......@@ -273,7 +273,6 @@ on a single tesla V100 16GB. The data for XNLI can be downloaded with the follow
export XNLI_DIR=/path/to/XNLI
python run_xnli.py \
--model_type bert \
--model_name_or_path bert-base-multilingual-cased \
--language de \
--train_language en \
......
......@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
""" Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM).
Adapted from `examples/text-classification/run_glue.py`"""
......@@ -32,15 +32,9 @@ from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
BertConfig,
BertForSequenceClassification,
BertTokenizer,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer,
XLMConfig,
XLMForSequenceClassification,
XLMTokenizer,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
)
from transformers import glue_convert_examples_to_features as convert_examples_to_features
......@@ -57,16 +51,6 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
)
MODEL_CLASSES = {
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
}
def set_seed(args):
random.seed(args.seed)
......@@ -377,19 +361,12 @@ def main():
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--language",
......@@ -421,7 +398,7 @@ def main():
)
parser.add_argument(
"--cache_dir",
default="",
default=None,
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
......@@ -562,24 +539,23 @@ def main():
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(
config = AutoConfig.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)
tokenizer = tokenizer_class.from_pretrained(
args.model_type = config.model_type
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)
model = model_class.from_pretrained(
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)
if args.local_rank == 0:
......@@ -614,14 +590,13 @@ def main():
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
......@@ -633,7 +608,7 @@ def main():
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
model = model_class.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
......
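``run_xnli.py`` now infers ``model_type`` from the loaded config instead of a required CLI flag, and routes everything through Auto classes. A sketch of the loading sequence with illustrative argument values (XNLI has three labels):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

model_name_or_path = "bert-base-multilingual-cased"  # illustrative
config = AutoConfig.from_pretrained(model_name_or_path, num_labels=3, finetuning_task="xnli")
model_type = config.model_type  # replaces the old --model_type argument
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=config)
```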
......@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """
""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
import logging
......
......@@ -159,7 +159,6 @@ if is_torch_available():
AutoModelWithLMHead,
AutoModelForTokenClassification,
AutoModelForMultipleChoice,
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
......@@ -180,7 +179,7 @@ if is_torch_available():
BertForTokenClassification,
BertForQuestionAnswering,
load_tf_weights_in_bert,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertLayer,
)
from .modeling_openai import (
......@@ -189,7 +188,7 @@ if is_torch_available():
OpenAIGPTLMHeadModel,
OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_transfo_xl import (
TransfoXLPreTrainedModel,
......@@ -197,7 +196,7 @@ if is_torch_available():
TransfoXLLMHeadModel,
AdaptiveEmbedding,
load_tf_weights_in_transfo_xl,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_gpt2 import (
GPT2PreTrainedModel,
......@@ -205,9 +204,9 @@ if is_torch_available():
GPT2LMHeadModel,
GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
from .modeling_xlnet import (
XLNetPreTrainedModel,
XLNetModel,
......@@ -218,7 +217,7 @@ if is_torch_available():
XLNetForQuestionAnsweringSimple,
XLNetForQuestionAnswering,
load_tf_weights_in_xlnet,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_xlm import (
XLMPreTrainedModel,
......@@ -228,13 +227,13 @@ if is_torch_available():
XLMForTokenClassification,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_bart import (
BartForSequenceClassification,
BartModel,
BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_marian import MarianMTModel
from .tokenization_marian import MarianTokenizer
......@@ -245,7 +244,7 @@ if is_torch_available():
RobertaForMultipleChoice,
RobertaForTokenClassification,
RobertaForQuestionAnswering,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_distilbert import (
DistilBertPreTrainedModel,
......@@ -254,7 +253,7 @@ if is_torch_available():
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DistilBertForTokenClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_camembert import (
CamembertForMaskedLM,
......@@ -263,7 +262,7 @@ if is_torch_available():
CamembertForMultipleChoice,
CamembertForTokenClassification,
CamembertForQuestionAnswering,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_t5 import (
......@@ -271,7 +270,7 @@ if is_torch_available():
T5Model,
T5ForConditionalGeneration,
load_tf_weights_in_t5,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_albert import (
AlbertPreTrainedModel,
......@@ -282,7 +281,7 @@ if is_torch_available():
AlbertForQuestionAnswering,
AlbertForTokenClassification,
load_tf_weights_in_albert,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_xlm_roberta import (
XLMRobertaForMaskedLM,
......@@ -290,7 +289,7 @@ if is_torch_available():
XLMRobertaForMultipleChoice,
XLMRobertaForSequenceClassification,
XLMRobertaForTokenClassification,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
......@@ -300,7 +299,7 @@ if is_torch_available():
FlaubertForSequenceClassification,
FlaubertForQuestionAnswering,
FlaubertForQuestionAnsweringSimple,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_electra import (
......@@ -311,7 +310,7 @@ if is_torch_available():
ElectraForSequenceClassification,
ElectraModel,
load_tf_weights_in_electra,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_reformer import (
......@@ -319,7 +318,7 @@ if is_torch_available():
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_longformer import (
......@@ -329,7 +328,7 @@ if is_torch_available():
LongformerForMultipleChoice,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
# Optimization
......@@ -367,7 +366,6 @@ if is_tf_available():
TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead,
TFAutoModelForTokenClassification,
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_MODEL_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
......@@ -388,7 +386,7 @@ if is_tf_available():
TFBertForMultipleChoice,
TFBertForTokenClassification,
TFBertForQuestionAnswering,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_gpt2 import (
......@@ -397,7 +395,7 @@ if is_tf_available():
TFGPT2Model,
TFGPT2LMHeadModel,
TFGPT2DoubleHeadsModel,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_openai import (
......@@ -406,7 +404,7 @@ if is_tf_available():
TFOpenAIGPTModel,
TFOpenAIGPTLMHeadModel,
TFOpenAIGPTDoubleHeadsModel,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_transfo_xl import (
......@@ -414,7 +412,7 @@ if is_tf_available():
TFTransfoXLMainLayer,
TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
)
......@@ -426,7 +424,7 @@ if is_tf_available():
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_xlm import (
......@@ -436,7 +434,7 @@ if is_tf_available():
TFXLMWithLMHeadModel,
TFXLMForSequenceClassification,
TFXLMForQuestionAnsweringSimple,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_xlm_roberta import (
......@@ -444,7 +442,7 @@ if is_tf_available():
TFXLMRobertaModel,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_roberta import (
......@@ -455,7 +453,7 @@ if is_tf_available():
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaForQuestionAnswering,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_camembert import (
......@@ -463,14 +461,14 @@ if is_tf_available():
TFCamembertForMaskedLM,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_flaubert import (
TFFlaubertModel,
TFFlaubertWithLMHeadModel,
TFFlaubertForSequenceClassification,
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_distilbert import (
......@@ -481,14 +479,14 @@ if is_tf_available():
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForQuestionAnswering,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_ctrl import (
TFCTRLPreTrainedModel,
TFCTRLModel,
TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_albert import (
......@@ -500,14 +498,14 @@ if is_tf_available():
TFAlbertForMultipleChoice,
TFAlbertForSequenceClassification,
TFAlbertForQuestionAnswering,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_t5 import (
TFT5PreTrainedModel,
TFT5Model,
TFT5ForConditionalGeneration,
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_electra import (
......@@ -516,7 +514,7 @@ if is_tf_available():
TFElectraForPreTraining,
TFElectraForMaskedLM,
TFElectraForTokenClassification,
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
# Optimization
......
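Every public ``*_PRETRAINED_MODEL_ARCHIVE_MAP`` dict export becomes a ``*_PRETRAINED_MODEL_ARCHIVE_LIST`` list of ids. A quick sketch of the new shape (the printed values in the comment are indicative only):

```python
from transformers import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

assert isinstance(BERT_PRETRAINED_MODEL_ARCHIVE_LIST, list)
print(BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:2])  # e.g. ['bert-base-uncased', 'bert-large-uncased']
```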
......@@ -32,7 +32,7 @@ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class AlbertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.AlbertModel`.
This is the configuration class to store the configuration of a :class:`~transformers.AlbertModel`.
It is used to instantiate an ALBERT model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the ALBERT `xxlarge <https://huggingface.co/albert-xxlarge-v2>`__ architecture.
......@@ -97,13 +97,8 @@ class AlbertConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "albert"
def __init__(
......
......@@ -113,12 +113,12 @@ class AutoConfig:
)
@classmethod
def for_model(cls, model_type, *args, **kwargs):
for pattern, config_class in CONFIG_MAPPING.items():
if pattern in model_type:
def for_model(cls, model_type: str, *args, **kwargs):
if model_type in CONFIG_MAPPING:
config_class = CONFIG_MAPPING[model_type]
return config_class(*args, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contain one of {}".format(
"Unrecognized model identifier: {}. Should contain one of {}".format(
model_type, ", ".join(CONFIG_MAPPING.keys())
)
)
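``AutoConfig.for_model`` now requires an exact ``model_type`` key rather than matching by substring. A sketch of the behavioral difference (hypothetical calls):

```python
from transformers import AutoConfig

config = AutoConfig.for_model("bert")  # exact key in CONFIG_MAPPING: returns a BertConfig
# AutoConfig.for_model("bert-base-uncased")  # previously matched "bert" by substring;
#                                            # now raises ValueError (no exact key)
```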
......@@ -130,24 +130,24 @@ class AutoConfig:
The configuration class to instantiate is selected
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
- contains `t5`: :class:`~transformers.T5Config` (T5 model)
- contains `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- contains `bert`: :class:`~transformers.BertConfig` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
- contains `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model)
- contains `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMConfig` (XLM model)
- contains `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
- contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
- contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: :class:`~transformers.T5Config` (T5 model)
- `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
- `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
- `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
- `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
- `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
- `bert`: :class:`~transformers.BertConfig` (Bert model)
- `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
- `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
- `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model)
- `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
- `xlm`: :class:`~transformers.XLMConfig` (XLM model)
- `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
- `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
- `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
Args:
pretrained_model_name_or_path (:obj:`string`):
......@@ -193,9 +193,7 @@ class AutoConfig:
assert unused_kwargs == {'foo': False}
"""
config_dict, _ = PretrainedConfig.get_config_dict(
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
)
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]]
......
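With the global ``ALL_PRETRAINED_CONFIG_ARCHIVE_MAP`` unhooked, ``get_config_dict`` resolves the identifier directly, and the concrete config class is picked from the ``model_type`` field of the downloaded config. A hedged sketch:

```python
from transformers import AutoConfig

# Resolution happens from the identifier alone; no archive-map lookup is involved.
config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese")
assert config.model_type == "bert"
```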
......@@ -23,11 +23,11 @@ from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
}
......@@ -36,7 +36,6 @@ class BartConfig(PretrainedConfig):
Configuration class for Bart. Parameters are renamed from the fairseq implementation
"""
model_type = "bart"
pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......
......@@ -39,13 +39,14 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
"cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
"cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
"TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
"TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
"wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
# See all BERT models at https://huggingface.co/models?filter=bert
}
......@@ -102,12 +103,7 @@ class BertConfig(PretrainedConfig):
# Accessing the model configuration
configuration = model.config
Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "bert"
def __init__(
......