Commit b1978698 authored by thomwolf

unified tokenizer api and serialization + tests

parent 3d5f2913
......@@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForSequenceClassification
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
XLMTokenizer)
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
......@@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
MODEL_CLASSES = {
'bert': BertForSequenceClassification,
'xlnet': XLNetForSequenceClassification,
'xlm': XLMForSequenceClassification,
}
TOKENIZER_CLASSES = {
'bert': BertTokenizer,
'xlnet': XLNetTokenizer,
'xlm': XLMTokenizer,
}
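A minimal sketch, not part of the diff, of how main() below uses these lookup tables (the model name and num_labels are placeholder values):

# Sketch: resolving the tokenizer and model classes from the --model_name argument,
# mirroring the dispatch performed later in main().
model_name = 'xlnet-large-cased'                  # hypothetical CLI value
model_type = model_name.lower().split('-')[0]     # -> 'xlnet'
tokenizer = TOKENIZER_CLASSES[model_type].from_pretrained(model_name, do_lower_case=False)
model = MODEL_CLASSES[model_type].from_pretrained(model_name, num_labels=2)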
def train(args, train_features, model):
""" Train the model """
......@@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
# Eval!
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Num examples = %d", len(eval_features))
logger.info(" Batch size = %d", args.eval_batch_size)
model.eval()
eval_loss = 0
......@@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
examples = processor.get_dev_examples(args.data_dir)
cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
'dev' if eval else 'train',
list(filter(None, args.bert_model.split('/'))).pop(),
list(filter(None, args.model_name.split('/'))).pop(),
str(args.max_seq_length),
str(task)))
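For illustration, a sketch of the cache file name this format string now produces, with hypothetical argument values (nothing below is taken from the diff itself):

import os

# Hypothetical argument values, for illustration only.
model_name, max_seq_length, task, data_dir = 'xlnet-large-cased', 128, 'mrpc', '/tmp/glue/MRPC'
cached_features_file = os.path.join(data_dir, '{}_{}_{}_{}'.format(
    'train',
    list(filter(None, model_name.split('/'))).pop(),
    str(max_seq_length),
    str(task)))
print(cached_features_file)   # -> /tmp/glue/MRPC/train_xlnet-large-cased_128_mrpc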
......@@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
features = torch.load(cached_features_file)
else:
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
......@@ -230,12 +252,10 @@ def main():
## Required parameters
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--model_name", default=None, type=str, required=True,
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--task_name", default=None, type=str, required=True,
help="The name of the task to train.")
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
......@@ -243,9 +263,8 @@ def main():
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
......@@ -263,8 +282,7 @@ def main():
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
......@@ -331,8 +349,11 @@ def main():
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
args.model_type = args.model_name.lower().split('-')[0]
args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
args.model_class = MODEL_CLASSES[args.model_type]
tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
if args.local_rank == 0:
torch.distributed.barrier()
......@@ -359,27 +380,16 @@ def main():
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
model.save_pretrained(args.output_dir)
tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned
model = BertForSequenceClassification.from_pretrained(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model)
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
model.to(args.device)
# Load a trained model and vocabulary that you have fine-tuned
model = args.model_class.from_pretrained(args.output_dir)
tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
......
......@@ -211,8 +211,8 @@ def main():
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
cls_token_at_end=True, cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
......@@ -369,8 +369,8 @@ def main():
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
cls_token_at_end=True, cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
......
......@@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` defines the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
`cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
"""
......@@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("label: %s (id = %d)" % (example.label, label_id))
features.append(
......
......@@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering,
load_tf_weights_in_bert)
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt)
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2)
load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetConfig,
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet)
load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMConfig, XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering)
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
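With the model-specific prefixes, archive maps from several modeling modules can now be imported together from the package root, which is what the ALL_MODELS snippet in run_glue.py above relies on. A minimal sketch of listing the available shortcut names:

from pytorch_transformers import (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                  GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
                                  XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
                                  XLM_PRETRAINED_MODEL_ARCHIVE_MAP)

# Shortcut names accepted by from_pretrained() for each model family.
for family, archive_map in [('bert', BERT_PRETRAINED_MODEL_ARCHIVE_MAP),
                            ('gpt2', GPT2_PRETRAINED_MODEL_ARCHIVE_MAP),
                            ('xlnet', XLNET_PRETRAINED_MODEL_ARCHIVE_MAP),
                            ('xlm', XLM_PRETRAINED_MODEL_ARCHIVE_MAP)]:
    print(family, sorted(archive_map.keys()))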
......
......@@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
TransfoXLConfig,
TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME,
VOCAB_NAME)
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
with open(transfo_xl_dataset_file, "rb") as fp:
corpus = pickle.load(fp, encoding="latin1")
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
corpus_vocab_dict = corpus.vocab.__dict__
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
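The vocabulary path is now derived from the tokenizer module's VOCAB_FILES_NAMES mapping rather than a standalone VOCAB_NAME constant. A hedged sketch of the idea; the filename value and dump directory below are assumptions, not read from this diff:

import os

# Assumed shape of the mapping: logical key -> on-disk filename (the value is illustrative).
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'}

pytorch_dump_folder_path = '/tmp/transfo-xl-dump'   # hypothetical output directory
pytorch_vocab_dump_path = os.path.join(pytorch_dump_folder_path,
                                       VOCAB_FILES_NAMES['pretrained_vocab_file'])
print(pytorch_vocab_dump_path)                      # -> /tmp/transfo-xl-dump/vocab.bin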
......
......@@ -24,7 +24,7 @@ import torch
import numpy
from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
......@@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model, pytorch_weights_dump_path)
......
......@@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
......@@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
......@@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `BertModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
......@@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = BertConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
......
......@@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
......@@ -103,7 +103,7 @@ def gelu(x):
class GPT2Config(PretrainedConfig):
"""Configuration class to store the configuration of a `GPT2Model`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = GPT2Config
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
......
......@@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
......@@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
class OpenAIGPTConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = OpenAIGPTConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer"
......
......@@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrai
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
......@@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=267735,
......@@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
......
......@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_to_prune._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
""" Save a model with its configuration file to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, 'module') else self
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
output_config_file = os.path.join(save_directory, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
......
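The new save_pretrained() method pairs with the existing from_pretrained() class method, replacing the manual torch.save / to_json_file calls shown in run_glue.py above. A minimal round-trip sketch; the directory path and num_labels value are placeholders:

import os
from pytorch_transformers import BertForSequenceClassification

save_directory = '/tmp/finetuned-bert'     # hypothetical path, not taken from the diff
if not os.path.isdir(save_directory):
    os.makedirs(save_directory)            # save_pretrained() expects an existing directory

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.save_pretrained(save_directory)      # writes the WEIGHTS_NAME and CONFIG_NAME files
model = BertForSequenceClassification.from_pretrained(save_directory)   # reload round trip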
......@@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
}
......@@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
class XLMConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLMModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30145,
......@@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = XLMConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "transformer"
......
......@@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
......@@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class XLNetConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLNetModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=32000,
......@@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
a simple interface for downloading and loading pretrained models.
"""
config_class = XLNetConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer"
......
......@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
......@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -413,7 +413,7 @@ class GPTModelTester(object):
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model)
......
......@@ -26,7 +26,7 @@ import pytest
import torch
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
......@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -20,12 +20,12 @@ import unittest
import logging
from pytorch_transformers import PretrainedConfig, PreTrainedModel
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, PretrainedConfig)
......
......@@ -21,7 +21,7 @@ import shutil
import pytest
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
......@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -26,7 +26,7 @@ import pytest
import torch
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
......@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from io import open
import shutil
import pytest
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
_is_whitespace)
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
_is_whitespace, VOCAB_FILES_NAMES)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
"##ing", ",", "low", "lowest",
]
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_directory = "/tmp/"
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
tokenizer = BertTokenizer(vocab_file)
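Outside the test harness, the same directory-based round trip looks roughly like the sketch below (the temporary directory and toy vocabulary are illustrative; the expected tokenization follows from the vocab entries above):

import os
import tempfile
from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab_directory = tempfile.mkdtemp()
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
    vocab_writer.write("".join(token + "\n" for token in vocab_tokens))

tokenizer = BertTokenizer.from_pretrained(vocab_directory, do_lower_case=True)
print(tokenizer.tokenize("unwanted running"))   # expected: ['un', '##want', '##ed', 'runn', '##ing']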
......@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab)
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
self.assertListEqual(tokenizer.tokenize(""), [])
......