Commit b1978698 authored by thomwolf

unified tokenizer api and serialization + tests

parent 3d5f2913
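Summary of the change: run_glue.py no longer hard-codes the BERT classes. It derives a model type from the --model_name argument, looks the model and tokenizer classes up in two small dispatch tables, and routes serialization through the new save_pretrained()/from_pretrained() pair. A minimal sketch of the dispatch, using only names introduced in this diff (the model name and num_labels value below are illustrative):

from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
                                  XLMForSequenceClassification,
                                  BertTokenizer, XLNetTokenizer, XLMTokenizer)

MODEL_CLASSES = {'bert': BertForSequenceClassification,
                 'xlnet': XLNetForSequenceClassification,
                 'xlm': XLMForSequenceClassification}
TOKENIZER_CLASSES = {'bert': BertTokenizer,
                     'xlnet': XLNetTokenizer,
                     'xlm': XLMTokenizer}

model_name = 'xlnet-large-cased'                 # illustrative; any key of the pretrained archive maps works
model_type = model_name.lower().split('-')[0]    # -> 'xlnet'

tokenizer = TOKENIZER_CLASSES[model_type].from_pretrained(model_name, do_lower_case=False)
model = MODEL_CLASSES[model_type].from_pretrained(model_name, num_labels=2)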
@@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter

-from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_transformers.modeling_bert import BertForSequenceClassification
-from pytorch_transformers.tokenization_bert import BertTokenizer
+from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
+                                  XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                  XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
+from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
+                                  XLMTokenizer)
 from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
 from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
@@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
 logger = logging.getLogger(__name__)

+ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
+
+MODEL_CLASSES = {
+    'bert': BertForSequenceClassification,
+    'xlnet': XLNetForSequenceClassification,
+    'xlm': XLMForSequenceClassification,
+}
+
+TOKENIZER_CLASSES = {
+    'bert': BertTokenizer,
+    'xlnet': XLNetTokenizer,
+    'xlm': XLMTokenizer,
+}
+
 def train(args, train_features, model):
     """ Train the model """
@@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
     # Eval!
     logger.info("***** Running evaluation *****")
-    logger.info(" Num examples = %d", len(eval_examples))
+    logger.info(" Num examples = %d", len(eval_features))
     logger.info(" Batch size = %d", args.eval_batch_size)
     model.eval()
     eval_loss = 0
@@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         examples = processor.get_dev_examples(args.data_dir)
     cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
         'dev' if eval else 'train',
-        list(filter(None, args.bert_model.split('/'))).pop(),
+        list(filter(None, args.model_name.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
@@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         features = torch.load(cached_features_file)
     else:
-        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
+        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
+                                                cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
+                                                cls_token=tokenizer.cls_token,
+                                                sep_token=tokenizer.sep_token, cls_token_segment_id=2,
+                                                pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
@@ -230,12 +252,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--bert_model", default=None, type=str, required=True,
-                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
-                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
-                             "bert-base-multilingual-cased, bert-base-chinese.")
+    parser.add_argument("--model_name", default=None, type=str, required=True,
+                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
-                        help="The name of the task to train.")
+                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
@@ -243,9 +263,8 @@ def main():
     parser.add_argument("--cache_dir", default="", type=str,
                         help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after WordPiece tokenization. \n"
-                             "Sequences longer than this will be truncated, and sequences shorter \n"
-                             "than this will be padded.")
+                        help="The maximum total input sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action='store_true',
@@ -263,8 +282,7 @@ def main():
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                             "E.g., 0.1 = 10%% of training.")
+                        help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--overwrite_output_dir', action='store_true',
@@ -331,8 +349,11 @@ def main():
         # Make sure only the first process in distributed training will download model & vocab
         torch.distributed.barrier()

-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
+    args.model_type = args.model_name.lower().split('-')[0]
+    args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
+    args.model_class = MODEL_CLASSES[args.model_type]
+    tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
+    model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)

     if args.local_rank == 0:
         torch.distributed.barrier()
@@ -359,27 +380,16 @@ def main():
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Save a trained model, configuration and tokenizer
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
+        model.save_pretrained(args.output_dir)
         tokenizer.save_vocabulary(args.output_dir)
-        # Load a trained model and vocabulary that you have fine-tuned
-        model = BertForSequenceClassification.from_pretrained(args.output_dir)
-        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
         # Good practice: save your training arguments together with the trained model
-        output_args_file = os.path.join(args.output_dir, 'training_args.bin')
-        torch.save(args, output_args_file)
-    else:
-        model = BertForSequenceClassification.from_pretrained(args.bert_model)
-    model.to(args.device)
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = args.model_class.from_pretrained(args.output_dir)
+        tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
+        model.to(args.device)

     # Evaluation
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...
@@ -211,8 +211,8 @@ def main():
         logger.info("No cache file at %s, preparing train features", cached_train_features_file)
         train_features = convert_examples_to_features(
             train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-            cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-            sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+            cls_token_at_end=True, cls_token=tokenizer.cls_token,
+            sep_token=tokenizer.sep_token, cls_token_segment_id=2,
             pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info(" Saving train features into cached file %s", cached_train_features_file)
@@ -369,8 +369,8 @@ def main():
         logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
         eval_features = convert_examples_to_features(
             eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-            cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-            sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+            cls_token_at_end=True, cls_token=tokenizer.cls_token,
+            sep_token=tokenizer.sep_token, cls_token_segment_id=2,
             pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
...
@@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                                  mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
         `cls_token_at_end` define the location of the CLS token:
-            - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
+            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
         `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
     """
@@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                 [str(x) for x in tokens]))
             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info(
-                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
             logger.info("label: %s (id = %d)" % (example.label, label_id))

         features.append(
...
@@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
                             BertForTokenClassification, BertForQuestionAnswering,
-                            load_tf_weights_in_bert)
+                            load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                            BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
-                              load_tf_weights_in_openai_gpt)
+                              load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                              OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
-                                  load_tf_weights_in_transfo_xl)
+                                  load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                  TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             GPT2LMHeadModel, GPT2DoubleHeadsModel,
-                            load_tf_weights_in_gpt2)
+                            load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                            GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlnet import (XLNetConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
                              XLNetForSequenceClassification, XLNetForQuestionAnswering,
-                             load_tf_weights_in_xlnet)
+                             load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                             XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlm import (XLMConfig, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
-                           XLMForQuestionAnswering)
+                           XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                           XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
...
@@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
                                                        TransfoXLConfig,
                                                        TransfoXLLMHeadModel,
                                                        load_tf_weights_in_transfo_xl)
-from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME,
-                                                           VOCAB_NAME)
+from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)

 if sys.version_info[0] == 2:
     import cPickle as pickle
@@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
         with open(transfo_xl_dataset_file, "rb") as fp:
             corpus = pickle.load(fp, encoding="latin1")
         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
-        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
         print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
         corpus_vocab_dict = corpus.vocab.__dict__
         torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
...
@@ -24,7 +24,7 @@ import torch
 import numpy

 from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
-from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME
+from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES

 def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
@@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
     # Save pytorch-model
     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']

     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(model, pytorch_weights_dump_path)
...
@@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {
+BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
@@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
 }

-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
@@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class BertConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `BertModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
                  vocab_size_or_config_json_file=30522,
@@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = BertConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
...
@@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
-                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
-                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
+GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}

 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
@@ -103,7 +103,7 @@ def gelu(x):
 class GPT2Config(PretrainedConfig):
     """Configuration class to store the configuration of a `GPT2Model`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(
         self,
@@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = GPT2Config
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
...
@@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
+OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
+OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}

 def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
@@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
 class OpenAIGPTConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `OpenAIGPTModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(
         self,
@@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = OpenAIGPTConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
...
@@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
 }
@@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
 class TransfoXLConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `TransfoXLModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
                  vocab_size_or_config_json_file=267735,
@@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = TransfoXLConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"
...
@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
         model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
         model_to_prune._prune_heads(heads_to_prune)

+    def save_pretrained(self, save_directory):
+        """ Save a model with its configuration file to a directory, so that it
+            can be re-loaded using the `from_pretrained(save_directory)` class method.
+        """
+        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+
+        # Only save the model it-self if we are using distributed training
+        model_to_save = self.module if hasattr(self, 'module') else self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+        output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
...
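The save_pretrained() method added above is the counterpart of the existing from_pretrained() class method, and is what run_glue.py now calls instead of juggling WEIGHTS_NAME/CONFIG_NAME itself. A minimal round-trip sketch; the output directory is illustrative, and it must already exist because save_pretrained() asserts on it:

import os
from pytorch_transformers import BertForSequenceClassification, BertTokenizer

output_dir = '/tmp/glue_output'            # illustrative path
if not os.path.exists(output_dir):
    os.makedirs(output_dir)                # save_pretrained() expects an existing directory

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model.save_pretrained(output_dir)          # writes WEIGHTS_NAME and CONFIG_NAME into output_dir
tokenizer.save_vocabulary(output_dir)      # writes the vocabulary file next to them

# Reload the (fine-tuned) weights and vocabulary from the same directory
model = BertForSequenceClassification.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)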
@@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
 }
@@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
 class XLMConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLMModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
                  vocab_size_or_config_json_file=30145,
@@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLMConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
     base_model_prefix = "transformer"
...
@@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
 }
@@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class XLNetConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLNetModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
                  vocab_size_or_config_json_file=32000,
@@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLNetConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"
...
@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForNextSentencePrediction, BertForPreTraining,
                                   BertForQuestionAnswering, BertForSequenceClassification,
                                   BertForTokenClassification, BertForMultipleChoice)
-from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
...
@@ -413,7 +413,7 @@ class GPTModelTester(object):
     def create_and_check_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
             model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.parent.assertIsNotNone(model)
...
@@ -26,7 +26,7 @@ import pytest
 import torch

 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP

 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
...
@@ -20,12 +20,12 @@ import unittest
 import logging

 from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP

 class ModelUtilsTest(unittest.TestCase):
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             config = BertConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, PretrainedConfig)
...
@@ -21,7 +21,7 @@ import shutil
 import pytest

 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
-from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP

 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
...
@@ -26,7 +26,7 @@ import pytest
 import torch

 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
-from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP

 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
...
@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import shutil
-import pytest

 from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                      BertTokenizer,
                                                      WordpieceTokenizer,
                                                      _is_control, _is_punctuation,
-                                                     _is_whitespace)
+                                                     _is_whitespace, VOCAB_FILES_NAMES)

 from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing", ","
+            "##ing", ",", "low", "lowest",
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
+        vocab_directory = "/tmp/"
+        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
+        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

             vocab_file = vocab_writer.name

-        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
+        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)

         tokenizer = BertTokenizer(vocab_file)
@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
-        tokenizer = WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

         self.assertListEqual(tokenizer.tokenize(""), [])
...
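The tokenizer test now writes its toy vocabulary under the filename the library expects (VOCAB_FILES_NAMES['vocab_file']) so the tokenizer can also be loaded from the containing directory, mirroring how run_glue.py reloads a saved tokenizer. A sketch of the same pattern outside the test harness; the path and vocabulary are illustrative and Python 3 is assumed:

import os
from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ","]
vocab_directory = "/tmp/"                                        # illustrative location
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])

with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
    vocab_writer.write("".join(token + "\n" for token in vocab_tokens))

# The tokenizer can be built from the explicit vocab file or loaded from the directory it lives in
tokenizer_from_file = BertTokenizer(vocab_file)
tokenizer_from_dir = BertTokenizer.from_pretrained(vocab_directory)
print(tokenizer_from_dir.tokenize("unwanted running"))           # expected with this vocab: ['un', '##want', '##ed', 'runn', '##ing']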