"...resnet50_tensorflow.git" did not exist on "e54c17f711b09a19371f24574188942b67426803"
Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length: 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names that make the code easier to understand.
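
For reference, black can pick up the same settings from a pyproject.toml at the
repository root, so future runs and CI don't have to repeat the flag. A minimal
sketch (this commit doesn't add one; [tool.black] and line-length are black's
standard configuration keys):

    [tool.black]
    line-length = 119

A CI job could then enforce the style with `black --check .`, which exits
non-zero whenever a file would be reformatted.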
parent 63e3827c
@@ -24,82 +24,270 @@ import tensorflow as tf

from transformers import is_torch_available, cached_path
from transformers import (
    load_pytorch_checkpoint_in_tf2_model,
    BertConfig,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    GPT2Config,
    TFGPT2LMHeadModel,
    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLNetConfig,
    TFXLNetLMHeadModel,
    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLMConfig,
    TFXLMWithLMHeadModel,
    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    TransfoXLConfig,
    TFTransfoXLLMHeadModel,
    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    OpenAIGPTConfig,
    TFOpenAIGPTLMHeadModel,
    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    RobertaConfig,
    TFRobertaForMaskedLM,
    TFRobertaForSequenceClassification,
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DistilBertConfig,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFDistilBertForSequenceClassification,
    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CTRLConfig,
    TFCTRLLMHeadModel,
    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    AlbertConfig,
    TFAlbertForMaskedLM,
    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    T5Config,
    TFT5WithLMHeadModel,
    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
)

if is_torch_available():
    import torch
    import numpy as np

    from transformers import (
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DistilBertForMaskedLM,
        DistilBertForQuestionAnswering,
        DistilBertForSequenceClassification,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
    )
else:
    (
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DistilBertForMaskedLM,
        DistilBertForSequenceClassification,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
    ) = (
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    )

import logging

logging.basicConfig(level=logging.INFO)

MODEL_CLASSES = {
    "bert": (
        BertConfig,
        TFBertForPreTraining,
        BertForPreTraining,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "gpt2": (
        GPT2Config,
        TFGPT2LMHeadModel,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlnet": (
        XLNetConfig,
        TFXLNetLMHeadModel,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm": (
        XLMConfig,
        TFXLMWithLMHeadModel,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "transfo-xl": (
        TransfoXLConfig,
        TFTransfoXLLMHeadModel,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "openai-gpt": (
        OpenAIGPTConfig,
        TFOpenAIGPTLMHeadModel,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta": (
        RobertaConfig,
        TFRobertaForMaskedLM,
        RobertaForMaskedLM,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert": (
        DistilBertConfig,
        TFDistilBertForMaskedLM,
        DistilBertForMaskedLM,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert-base-uncased-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "ctrl": (
        CTRLConfig,
        TFCTRLLMHeadModel,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "albert": (
        AlbertConfig,
        TFAlbertForMaskedLM,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "t5": (
        T5Config,
        TFT5WithLMHeadModel,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
}


def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    if model_type not in MODEL_CLASSES:
        raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

@@ -116,17 +304,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file

    # Load weights from the PyTorch checkpoint (a shortcut name is downloaded from AWS)
    if pytorch_checkpoint_path in aws_model_maps:
        pytorch_checkpoint_path = cached_path(
            aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
        )

    # Load PyTorch checkpoint in tf2 model:
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

@@ -139,11 +329,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file

    # Save the TensorFlow model
    print("Save TensorFlow model to {}".format(tf_dump_path))
    tf_model.save_weights(tf_dump_path, save_format="h5")


def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    assert os.path.isdir(tf_dump_path), "--tf_dump_path should be a directory"

    if args_model_type is None:

@@ -156,7 +354,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc

        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(
                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
            )

        config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

@@ -166,9 +366,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc

            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
                    continue

@@ -176,7 +377,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc

            elif only_convert_finetuned_models:
                print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
                continue
            print(
                " Converting checkpoint {}/{}: {} - model_type {}".format(
                    i, len(aws_config_map), model_shortcut_name, model_type
                )
            )
            print("-" * 100)

            if config_shortcut_name in aws_config_map:

@@ -190,13 +395,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc

                model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)

            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)

@@ -205,39 +412,47 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output TensorFlow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
            list(MODEL_CLASSES.keys())
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
        "If not given, will download and convert all the checkpoints from AWS.",
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture. If not given and "
        "--pytorch_checkpoint_path is not given or is a shortcut name, "
        "use the configuration associated with the shortcut name on the AWS.",
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare TensorFlow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # if args.pytorch_checkpoint_path is not None:
@@ -248,11 +463,15 @@ if __name__ == "__main__":

    #     compare_with_pt_model=args.compare_with_pt_model,
    #     use_cached_models=args.use_cached_models)
    # else:
    convert_all_pt_checkpoints_to_tf(
        args.model_type.lower() if args.model_type is not None else None,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )

@@ -30,20 +30,27 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):

from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from transformers.modeling_bert import (
    BertConfig,
    BertEncoder,
    BertIntermediate,
    BertLayer,
    BertModel,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.modeling_roberta import (
    RobertaEmbeddings,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaModel,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):

@@ -61,7 +68,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

        intermediate_size=roberta.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = roberta.args.num_classes

@@ -74,7 +81,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

    # Embeddings
    model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
    model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.roberta.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c RoBERTa doesn't use them.
    model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
    model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias

@@ -85,11 +94,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

        ### self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert (
            roberta_layer.self_attn.k_proj.weight.data.shape
            == roberta_layer.self_attn.q_proj.weight.data.shape
            == roberta_layer.self_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )

        self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight

@@ -101,9 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

        ### self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
        self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight

@@ -111,28 +118,24 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

        ### intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
        intermediate.dense.weight = roberta_layer.fc1.weight
        intermediate.dense.bias = roberta_layer.fc1.bias

        ### output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
        bert_output.dense.weight = roberta_layer.fc2.weight
        bert_output.dense.bias = roberta_layer.fc2.bias
        bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
        #### end of layer

    if classification_head:
        model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight

@@ -143,21 +146,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

    model.lm_head.bias = roberta.model.decoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
    else:
        their_output = roberta.model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

@@ -169,23 +169,16 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_roberta_checkpoint_to_pytorch(
        args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )

@@ -24,8 +24,10 @@ import torch

from transformers import T5Config, T5Model, load_tf_weights_in_t5

import logging

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)

@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained T5 model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)

@@ -26,9 +26,8 @@ import torch

import transformers.tokenization_transfo_xl as data_utils

from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES

if sys.version_info[0] == 2:
    import cPickle as pickle

@@ -36,32 +35,33 @@ else:

    import pickle

import logging

logging.basicConfig(level=logging.INFO)

# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules["data_utils"] = data_utils
sys.modules["vocabulary"] = data_utils


def convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
):
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")

        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop("vocab", None)
        pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

@@ -92,26 +92,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the folder to store the PyTorch model or dataset/vocab.",
    )
    parser.add_argument(
        "--tf_checkpoint_path",
        default="",
        type=str,
        help="An optional path to a TensorFlow checkpoint path to be converted.",
    )
    parser.add_argument(
        "--transfo_xl_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained Transfo-XL model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--transfo_xl_dataset_file",
        default="",
        type=str,
        help="An optional dataset file to be converted into a vocabulary.",
    )
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(
        args.tf_checkpoint_path,
        args.transfo_xl_config_file,
        args.pytorch_dump_folder_path,
        args.transfo_xl_dataset_file,
    )

@@ -27,32 +27,34 @@ from transformers import CONFIG_NAME, WEIGHTS_NAME

from transformers.tokenization_xlm import VOCAB_FILES_NAMES

import logging

logging.basicConfig(level=logging.INFO)


def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
    # Load checkpoint
    chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")

    state_dict = chkpt["model"]

    # We have the base model one level deeper than the original XLM repository
    two_levels_state_dict = {}
    for k, v in state_dict.items():
        if "pred_layer" in k:
            two_levels_state_dict[k] = v
        else:
            two_levels_state_dict["transformer." + k] = v

    config = chkpt["params"]
    config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))

    vocab = chkpt["dico_word2id"]
    vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]

    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(two_levels_state_dict, pytorch_weights_dump_path)

@@ -69,15 +71,11 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)

@@ -22,11 +22,15 @@ import os

import argparse
import torch

from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    XLNetConfig,
    XLNetLMHeadModel,
    XLNetForQuestionAnswering,
    XLNetForSequenceClassification,
    load_tf_weights_in_xlnet,
)

GLUE_TASKS_NUM_LABELS = {
    "cola": 2,

@@ -41,9 +45,13 @@ GLUE_TASKS_NUM_LABELS = {

}

import logging

logging.basicConfig(level=logging.INFO)


def convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
):
    # Initialise PyTorch model
    config = XLNetConfig.from_json_file(bert_config_file)

@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py

        config.finetuning_task = finetuning_task
        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
        model = XLNetForSequenceClassification(config)
    elif "squad" in finetuning_task:
        config.finetuning_task = finetuning_task
        model = XLNetForQuestionAnswering(config)
    else:

@@ -75,30 +83,33 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
    )
    parser.add_argument(
        "--xlnet_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained XLNet model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the folder to store the PyTorch model or dataset/vocab.",
    )
    parser.add_argument(
        "--finetuning_task",
        default=None,
        type=str,
        help="Name of a task on which the XLNet TensorFlow model was fine-tuned",
    )
    args = parser.parse_args()
    print(args)

    convert_xlnet_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
    )

from .processors import (
    InputExample,
    InputFeatures,
    DataProcessor,
    SquadFeatures,
    SingleSentenceClassificationProcessor,
)
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels

from .metrics import is_sklearn_available

if is_sklearn_available():
    from .metrics import glue_compute_metrics, xnli_compute_metrics

@@ -23,20 +23,22 @@ logger = logging.getLogger(__name__)

try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score

    _has_sklearn = True
except (AttributeError, ImportError) as e:
    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
    _has_sklearn = False


def is_sklearn_available():
    return _has_sklearn


if _has_sklearn:

    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)

@@ -46,7 +48,6 @@ if _has_sklearn:

            "acc_and_f1": (acc + f1) / 2,
        }

    def pearson_and_spearman(preds, labels):
        pearson_corr = pearsonr(preds, labels)[0]
        spearman_corr = spearmanr(preds, labels)[0]

@@ -56,7 +57,6 @@ if _has_sklearn:

            "corr": (pearson_corr + spearman_corr) / 2,
        }

    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":

@@ -82,7 +82,6 @@ if _has_sklearn:

        else:
            raise KeyError(task_name)

    def xnli_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "xnli":

...
@@ -24,19 +24,21 @@ logger = logging.getLogger(__name__)

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):

    for example in examples:
        qas_id = example.qas_id
        gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]

        if not gold_answers:
            # For unanswerable questions, only correct answer is empty string
            gold_answers = [""]

        if qas_id not in preds:
            print("Missing prediction for %s" % qas_id)
            continue

        prediction = preds[qas_id]

...@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): ...@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
def make_eval_dict(exact_scores, f1_scores, qid_list=None): def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list: if not qid_list:
total = len(exact_scores) total = len(exact_scores)
return collections.OrderedDict([ return collections.OrderedDict(
('exact', 100.0 * sum(exact_scores.values()) / total), [
('f1', 100.0 * sum(f1_scores.values()) / total), ("exact", 100.0 * sum(exact_scores.values()) / total),
('total', total), ("f1", 100.0 * sum(f1_scores.values()) / total),
]) ("total", total),
]
)
else: else:
total = len(qid_list) total = len(qid_list)
return collections.OrderedDict([ return collections.OrderedDict(
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), [
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('total', total), ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
]) ("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix): def merge_eval(main_eval, new_eval, prefix):
for k in new_eval: for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k] main_eval["%s_%s" % (prefix, k)] = new_eval[k]
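A hedged sketch of how these two helpers combine (the per-question scores are invented; make_eval_dict and merge_eval are the functions above):

exact_scores = {"q1": 1.0, "q2": 0.0}  # hypothetical exact-match scores keyed by qas_id
f1_scores = {"q1": 1.0, "q2": 0.5}     # hypothetical F1 scores keyed by qas_id

main_eval = make_eval_dict(exact_scores, f1_scores)                # exact=50.0, f1=75.0, total=2
has_ans = make_eval_dict(exact_scores, f1_scores, qid_list=["q1"])
merge_eval(main_eval, has_ans, "HasAns")                           # adds HasAns_exact, HasAns_f1, HasAns_total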
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
...@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh
    main_eval["has_ans_exact"] = has_ans_exact
    main_eval["has_ans_f1"] = has_ans_f1

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
...@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh

def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
...@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
    exact, f1 = get_raw_scores(examples, preds)

    exact_threshold = apply_no_ans_threshold(
        exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
    )
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)

    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, "HasAns")

    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, "NoAns")

    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
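A sketch of driving squad_evaluate end to end, assuming this version of the library; the data directory and the all-empty placeholder predictions are assumptions, not taken from the diff:

from transformers.data.processors.squad import SquadV2Processor

examples = SquadV2Processor().get_dev_examples("data/squad")  # expects dev-v2.0.json in this directory
preds = {example.qas_id: "" for example in examples}          # placeholder predictions keyed by qas_id
evaluation = squad_evaluate(examples, preds)                  # OrderedDict with "exact", "f1", "total", ...
print(evaluation["exact"], evaluation["f1"])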
...@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1
...@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
...@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
                logger.info("Couldn't map end position")
            return orig_text

    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text
...@@ -393,8 +397,8 @@ def compute_predictions_logits(
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
...@@ -447,7 +451,9 @@ def compute_predictions_logits(
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
...@@ -455,14 +461,14 @@ def compute_predictions_logits(
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
...@@ -471,10 +477,10 @@ def compute_predictions_logits(
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
...@@ -498,31 +504,21 @@ def compute_predictions_logits(
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))

        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1
...@@ -551,8 +547,7 @@ def compute_predictions_logits(
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
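In concrete numbers (all invented), the thresholding above works like this:

score_null = 2.0                  # hypothetical start+end logit of the null (no-answer) prediction
best_span_score = 1.5             # hypothetical start_logit + end_logit of the best non-null span
score_diff = score_null - best_span_score           # 0.5
null_score_diff_threshold = 0.0
# 0.5 > 0.0, so this question would be predicted unanswerable ("")
prediction = "" if score_diff > null_score_diff_threshold else "best span text"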
...@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
    end_n_top,
    version_2_with_negative,
    tokenizer,
    verbose_logging,
):
    """ XLNet write prediction logic (more complex than Bert's).
        Write final predictions to the json file and log-odds of null if needed.
...@@ -594,12 +589,12 @@ def compute_predictions_log_probs(
        Requires utils_squad_evaluate.py
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
    )

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
    )

    logger.info("Writing predictions to: %s", output_prediction_file)
    # logger.info("Writing nbest to: %s" % (output_nbest_file))
...@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
                        start_index=start_index,
                        end_index=end_index,
                        start_log_prob=start_log_prob,
                        end_log_prob=end_log_prob,
                    )
                )

        prelim_predictions = sorted(
            prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
        )

        seen_predictions = {}
        nbest = []
...@@ -688,10 +684,10 @@ def compute_predictions_log_probs(
            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

            # Previously used Bert untokenizer
            tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

            # Clean whitespace
...@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
            else:
                do_lower_case = tokenizer.do_lowercase_and_remove_accent

            final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)

            if final_text in seen_predictions:
                continue
...@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
            )

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))

        total_scores = []
        best_non_null_entry = None
...

from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
\ No newline at end of file
...@@ -27,15 +27,18 @@ if is_tf_available():
logger = logging.getLogger(__name__)

def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    task=None,
    label_list=None,
    output_mode=None,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
):
    """
    Loads a data file into a list of ``InputFeatures``
...@@ -82,12 +85,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
            example = processor.get_example_from_tensor_dict(example)
            example = processor.tfds_map(example)

        inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
...@@ -106,8 +104,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        if output_mode == "classification":
            label = label_map[example.label]
...@@ -125,28 +127,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
            logger.info("label: %s (id = %d)" % (example.label, label))

        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
            )
        )

    if is_tf_available() and is_tf_dataset:

        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        return tf.data.Dataset.from_generator(
            gen,
            ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )

    return features
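A usage sketch for the reformatted converter; the checkpoint name, sentences and label are illustrative, and passing task="mrpc" lets the function derive label_list and output_mode from the MRPC processor:

from transformers import BertTokenizer
from transformers.data.processors.glue import glue_convert_examples_to_features
from transformers.data.processors.utils import InputExample

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = [InputExample(guid="train-1", text_a="He ate.", text_b="He had a meal.", label="1")]
features = glue_convert_examples_to_features(examples, tokenizer, max_length=128, task="mrpc")
print(features[0].input_ids[:10])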
...@@ -156,21 +166,21 @@ class MrpcProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -186,8 +196,7 @@ class MrpcProcessor(DataProcessor):
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
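All the GLUE processors in this file follow the same shape as MrpcProcessor; a hedged example (the glue_data/MRPC directory is an assumption and must contain the GLUE train.tsv/dev.tsv files):

processor = MrpcProcessor()
label_list = processor.get_labels()                        # ["0", "1"]
train_examples = processor.get_train_examples("glue_data/MRPC")
print(len(train_examples), train_examples[0].text_a)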
...@@ -196,21 +205,20 @@ class MnliProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_labels(self):
        """See base class."""
...@@ -226,8 +234,7 @@ class MnliProcessor(DataProcessor):
            text_a = line[8]
            text_b = line[9]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

...@@ -236,9 +243,7 @@ class MnliMismatchedProcessor(MnliProcessor):
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")

class ColaProcessor(DataProcessor):
...@@ -246,20 +251,20 @@ class ColaProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -272,8 +277,7 @@ class ColaProcessor(DataProcessor):
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

...@@ -282,20 +286,20 @@ class Sst2Processor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -310,8 +314,7 @@ class Sst2Processor(DataProcessor):
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

...@@ -320,20 +323,20 @@ class StsbProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -349,8 +352,7 @@ class StsbProcessor(DataProcessor):
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

...@@ -359,20 +361,20 @@ class QqpProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -391,8 +393,7 @@ class QqpProcessor(DataProcessor):
                label = line[5]
            except IndexError:
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

...@@ -401,21 +402,20 @@ class QnliProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

    def get_labels(self):
        """See base class."""
...@@ -431,8 +431,7 @@ class QnliProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

...@@ -441,20 +440,20 @@ class RteProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -470,8 +469,7 @@ class RteProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

...@@ -480,20 +478,20 @@ class WnliProcessor(DataProcessor):
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
...@@ -509,10 +507,10 @@ class WnliProcessor(DataProcessor):
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
...
...@@ -82,8 +82,8 @@ def _is_whitespace(c):
        return True
    return False

def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
    features = []
    if is_training and not example.is_impossible:
        # Get start and end position
...@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length,
        end_position = example.end_position

        # If the answer cannot be found in the text, then skip this example.
        actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
        if actual_text.find(cleaned_answer_text) == -1:
            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
...@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
    spans = []

    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
    sequence_added_tokens = (
        tokenizer.max_len - tokenizer.max_len_single_sentence + 1
        if "roberta" in str(type(tokenizer))
        else tokenizer.max_len - tokenizer.max_len_single_sentence
    )
    sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

    span_doc_tokens = all_doc_tokens
...@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
            return_overflowing_tokens=True,
            pad_to_max_length=True,
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
        )

        paragraph_len = min(
            len(all_doc_tokens) - len(spans) * doc_stride,
            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
        )

        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
            non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
        else:
            non_padded_ids = encoded_dict["input_ids"]

        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
...@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
    for doc_span_index in range(len(spans)):
        for j in range(spans[doc_span_index]["paragraph_len"]):
            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
            index = (
                j
                if tokenizer.padding_side == "left"
                else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            )
            spans[doc_span_index]["token_is_max_context"][index] = is_max_context

    for span in spans:
        # Identify the position of the CLS token
        cls_index = span["input_ids"].index(tokenizer.cls_token_id)

        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
        # Original TF implem also keep the classification token (set to 0) (not sure why...)
        p_mask = np.array(span["token_type_ids"])

        p_mask = np.minimum(p_mask, 1)
...@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

        features.append(
            SquadFeatures(
                span["input_ids"],
                span["attention_mask"],
                span["token_type_ids"],
                cls_index,
                p_mask.tolist(),
                example_index=0,  # Can not set unique_id and example_index here. They will be set after multiple processing.
                unique_id=0,
                paragraph_len=span["paragraph_len"],
                token_is_max_context=span["token_is_max_context"],
                tokens=span["tokens"],
                token_to_orig_map=span["token_to_orig_map"],
                start_position=start_position,
                end_position=end_position,
            )
        )
    return features

def squad_convert_example_to_features_init(tokenizer_for_convert):
    global tokenizer
    tokenizer = tokenizer_for_convert

def squad_convert_examples_to_features(
    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model.
    It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
...@@ -279,17 +290,28 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        )
    """
    # Defining helper methods
    features = []
    threads = min(threads, cpu_count())
    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
        annotate_ = partial(
            squad_convert_example_to_features,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=is_training,
        )
        features = list(
            tqdm(
                p.imap(annotate_, examples, chunksize=32),
                total=len(examples),
                desc="convert squad examples to features",
            )
        )
    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
        if not example_features:
            continue
        for example_feature in example_features:
...@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            example_index += 1
    features = new_features
    del new_features
    if return_dataset == "pt":
        if not is_torch_available():
            raise ImportError("Pytorch must be installed to return a pytorch dataset.")
...@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                    "input_ids": ex.input_ids,
                    "attention_mask": ex.attention_mask,
                    "token_type_ids": ex.token_type_ids,
                },
                {
                    "start_position": ex.start_position,
                    "end_position": ex.end_position,
                    "cls_index": ex.cls_index,
                    "p_mask": ex.p_mask,
                },
            )

        return tf.data.Dataset.from_generator(
...
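A sketch of the multiprocessing entry point above, assuming this version of the library; the paths and hyperparameters are placeholders, and the exact return value for return_dataset="pt" (features only vs. features plus a TensorDataset) depends on the library version, so it is not unpacked here:

from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV1Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = SquadV1Processor().get_dev_examples("data/squad")  # expects dev-v1.1.json here
result = squad_convert_examples_to_features(
    examples, tokenizer, max_seq_length=384, doc_stride=128, max_query_length=64,
    is_training=False, return_dataset="pt", threads=4,  # threads fans conversion out over the Pool above
)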
...@@ -24,6 +24,7 @@ from ...file_utils import is_tf_available, is_torch_available ...@@ -24,6 +24,7 @@ from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InputExample(object): class InputExample(object):
""" """
A single training/test example for simple sequence classification. A single training/test example for simple sequence classification.
...@@ -37,6 +38,7 @@ class InputExample(object): ...@@ -37,6 +38,7 @@ class InputExample(object):
label: (Optional) string. The label of the example. This should be label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples. specified for train and dev examples, but not for test examples.
""" """
def __init__(self, guid, text_a, text_b=None, label=None): def __init__(self, guid, text_a, text_b=None, label=None):
self.guid = guid self.guid = guid
self.text_a = text_a self.text_a = text_a
...@@ -99,14 +101,15 @@ class DataProcessor(object): ...@@ -99,14 +101,15 @@ class DataProcessor(object):
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line) line = list(unicode(cell, "utf-8") for cell in line)
lines.append(line) lines.append(line)
return lines return lines
class SingleSentenceClassificationProcessor(DataProcessor): class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set.""" """ Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
self.labels = [] if labels is None else labels self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples self.examples = [] if examples is None else examples
self.mode = mode self.mode = mode
...@@ -117,22 +120,24 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -117,22 +120,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
return SingleSentenceClassificationProcessor(labels=self.labels, return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
examples=self.examples[idx])
return self.examples[idx] return self.examples[idx]
@classmethod @classmethod
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1, def create_from_csv(
column_id=None, skip_first_row=False, **kwargs): cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
):
processor = cls(**kwargs) processor = cls(**kwargs)
processor.add_examples_from_csv(file_name, processor.add_examples_from_csv(
split_name=split_name, file_name,
column_label=column_label, split_name=split_name,
column_text=column_text, column_label=column_label,
column_id=column_id, column_text=column_text,
skip_first_row=skip_first_row, column_id=column_id,
overwrite_labels=True, skip_first_row=skip_first_row,
overwrite_examples=True) overwrite_labels=True,
overwrite_examples=True,
)
return processor return processor
@classmethod @classmethod
...@@ -141,8 +146,17 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -141,8 +146,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
processor.add_examples(texts_or_text_and_labels, labels=labels) processor.add_examples(texts_or_text_and_labels, labels=labels)
return processor return processor
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None, def add_examples_from_csv(
skip_first_row=False, overwrite_labels=False, overwrite_examples=False): self,
file_name,
split_name="",
column_label=0,
column_text=1,
column_id=None,
skip_first_row=False,
overwrite_labels=False,
overwrite_examples=False,
):
lines = self._read_tsv(file_name) lines = self._read_tsv(file_name)
if skip_first_row: if skip_first_row:
lines = lines[1:] lines = lines[1:]
...@@ -158,10 +172,13 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -158,10 +172,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid) ids.append(guid)
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples) return self.add_examples(
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
)
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None, def add_examples(
overwrite_labels=False, overwrite_examples=False): self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
):
assert labels is None or len(texts_or_text_and_labels) == len(labels) assert labels is None or len(texts_or_text_and_labels) == len(labels)
assert ids is None or len(texts_or_text_and_labels) == len(ids) assert ids is None or len(texts_or_text_and_labels) == len(ids)
if ids is None: if ids is None:
...@@ -192,13 +209,15 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -192,13 +209,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return self.examples return self.examples
def get_features(self, def get_features(
tokenizer, self,
max_length=None, tokenizer,
pad_on_left=False, max_length=None,
pad_token=0, pad_on_left=False,
mask_padding_with_zero=True, pad_token=0,
return_tensors=None): mask_padding_with_zero=True,
return_tensors=None,
):
""" """
Convert examples in a list of ``InputFeatures`` Convert examples in a list of ``InputFeatures``
...@@ -231,9 +250,7 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -231,9 +250,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
logger.info("Tokenizing example %d", ex_index) logger.info("Tokenizing example %d", ex_index)
input_ids = tokenizer.encode( input_ids = tokenizer.encode(
example.text_a, example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
) )
all_input_ids.append(input_ids) all_input_ids.append(input_ids)
@@ -256,8 +273,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)

            assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
                len(input_ids), batch_length
            )
            assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
                len(attention_mask), batch_length
            )

            if self.mode == "classification":
                label = label_map[example.label]
@@ -273,36 +294,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
                logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
                logger.info("label: %s (id = %d)" % (example.label, label))

            features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))

        if return_tensors is None:
            return features
        elif return_tensors == "tf":
            if not is_tf_available():
                raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
            import tensorflow as tf

            def gen():
                for ex in features:
                    yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)

            dataset = tf.data.Dataset.from_generator(
                gen,
                ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
                ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
            )
            return dataset
        elif return_tensors == "pt":
            if not is_torch_available():
                raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
            import torch
            from torch.utils.data import TensorDataset

            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
            if self.mode == "classification":
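A minimal usage sketch for the `get_features` path above, assuming the processor class is importable as shown (the sample sentences and labels are illustrative):

    # Sketch: tokenize stored examples and build a tf.data.Dataset via return_tensors="tf".
    from transformers import BertTokenizer
    from transformers.data.processors.utils import SingleSentenceClassificationProcessor

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SingleSentenceClassificationProcessor()
    processor.add_examples(["a first sentence", "a second sentence"], labels=["pos", "neg"])

    # Each element is ({"input_ids": ..., "attention_mask": ...}, label).
    dataset = processor.get_features(tokenizer, max_length=32, return_tensors="tf")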
...
@@ -24,11 +24,12 @@ from .utils import DataProcessor, InputExample
logger = logging.getLogger(__name__)


class XnliProcessor(DataProcessor):
    """Processor for the XNLI dataset.
    Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""

    def __init__(self, language, train_language=None):
        self.language = language
        self.train_language = train_language
@@ -40,13 +41,12 @@ class XnliProcessor(DataProcessor):
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % ("train", i)
            text_a = line[0]
            text_b = line[1]
            label = "contradiction" if line[2] == "contradictory" else line[2]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
@@ -59,19 +59,19 @@ class XnliProcessor(DataProcessor):
            language = line[0]
            if language != self.language:
                continue
            guid = "%s-%s" % ("test", i)
            text_a = line[6]
            text_b = line[7]
            label = line[1]
            assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]


xnli_processors = {
    "xnli": XnliProcessor,
}
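A usage sketch for the processor above; the data directory is illustrative and is assumed to hold the standard XNLI TSV layout the loaders expect:

    # Sketch: instantiate the XNLI processor for one evaluation language.
    from transformers.data.processors.xnli import XnliProcessor

    processor = XnliProcessor(language="fr", train_language="en")
    print(processor.get_labels())  # ['contradiction', 'entailment', 'neutral']
    # examples = processor.get_test_examples("/path/to/XNLI")  # path is illustrative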
...
@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import json
@@ -29,9 +29,10 @@ from filelock import FileLock
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

try:
    os.environ.setdefault("USE_TORCH", "YES")
    if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
        import torch

        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
@@ -41,10 +42,11 @@ except ImportError:
    _torch_available = False  # pylint: disable=invalid-name

try:
    os.environ.setdefault("USE_TF", "YES")
    if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
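Both probes run at import time, so the `USE_TORCH`/`USE_TF` flags only take effect when set before `transformers` is first imported. A sketch (flag values follow the tuples checked above):

    # Sketch: disable the TensorFlow probe before importing transformers.
    import os

    os.environ["USE_TF"] = "NO"      # any value outside ("1", "ON", "YES") skips the TF import
    os.environ["USE_TORCH"] = "YES"  # the default, made explicit here

    import transformers  # noqa: E402 -- flags must be set before this line

    print(transformers.is_tf_available())     # False: the probe was skipped
    print(transformers.is_torch_available())  # True only if torch is installed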
@@ -55,12 +57,13 @@ except (ImportError, AssertionError):
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
default_cache_path = os.path.join(torch_cache_home, "transformers")

try:
    from urllib.parse import urlparse
@@ -69,19 +72,21 @@ except ImportError:
try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
    )

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
@@ -95,38 +100,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


if not six.PY2:

    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            fn.__doc__ = "".join(docstr) + fn.__doc__
            return fn

        return docstring_decorator

    def add_end_docstrings(*docstr):
        def docstring_decorator(fn):
            fn.__doc__ = fn.__doc__ + "".join(docstr)
            return fn

        return docstring_decorator


else:
    # Not possible to update class docstrings on python2
    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn

        return docstring_decorator

    def add_end_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn

        return docstring_decorator
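On Python 3 the two decorators simply concatenate shared documentation onto a function's own docstring (and are no-ops on Python 2). A self-contained sketch of that behavior:

    # Sketch: prepend shared docs to a function, as add_start_docstrings does on PY3.
    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            fn.__doc__ = "".join(docstr) + fn.__doc__
            return fn

        return docstring_decorator


    @add_start_docstrings("Shared intro. ")
    def f():
        """Function-specific details."""


    print(f.__doc__)  # Shared intro. Function-specific details.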
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https", "s3")


def hf_bucket_url(identifier, postfix=None, cdn=False):
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
@@ -145,17 +160,17 @@ def url_to_filename(url, etag=None):
    so that TF 2.0 can identify it as a HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
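The cache filename is therefore fully deterministic: a sha256 of the URL, optionally followed by a period and a sha256 of the ETag. A sketch with illustrative values:

    # Sketch: reproduce the cache filename scheme used by url_to_filename.
    from hashlib import sha256

    url = "https://example.com/pytorch_model.bin"  # illustrative
    etag = '"0123456789abcdef"'                    # illustrative

    filename = sha256(url.encode("utf-8")).hexdigest()
    filename += "." + sha256(etag.encode("utf-8")).hexdigest()
    # ".h5" would additionally be appended if the URL ended in ".h5".
    print(filename)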
@@ -174,19 +189,21 @@ def filename_to_url(filename, cache_dir=None):
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
def cached_path(
    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
@@ -207,13 +224,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
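A usage sketch for `cached_path`, which is exported at the package root as the imports at the top of this commit show (the URL is illustrative):

    # Sketch: resolve a remote file to its local cache path, downloading if needed.
    from transformers import cached_path

    local_path = cached_path(
        "https://example.com/config.json",  # illustrative URL
        force_download=False,   # reuse a cached copy when present
        resume_download=False,  # True resumes from a ".incomplete" file
    )
    print(local_path)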
@@ -273,31 +295,35 @@ def s3_get(url, temp_file, proxies=None):
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, six.string_types):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.level <= logging.INFO),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
def get_from_cache(
    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
@@ -326,7 +352,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode("utf-8")
    filename = url_to_filename(url, etag)

    # get cache path to put the file
@@ -337,22 +363,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
    if not os.path.exists(cache_path) and etag is None:
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
@@ -366,7 +394,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
            )

            # GET file object
            if url.startswith("s3://"):
@@ -383,12 +413,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
            os.rename(temp_file.name, cache_path)

            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    output_string = unicode(output_string, "utf-8")  # The beauty of python 2
                meta_file.write(output_string)

    return cache_path
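Every cache entry thus gains a `.json` sidecar holding the url/etag pair, which `filename_to_url` reads back. A round-trip sketch with illustrative values:

    # Sketch: write and re-read the metadata sidecar created by get_from_cache.
    import json

    meta = {"url": "https://example.com/model.bin", "etag": '"abc123"'}  # illustrative
    with open("/tmp/entry.json", "w") as meta_file:
        meta_file.write(json.dumps(meta))

    with open("/tmp/entry.json", encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    print(metadata["url"], metadata["etag"])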
@@ -24,13 +24,14 @@ from tqdm import tqdm
ENDPOINT = "https://huggingface.co"


class S3Obj:
    def __init__(
        self,
        filename,  # type: str
        LastModified,  # type: str
        ETag,  # type: str
        Size,  # type: int
        **kwargs
    ):
        self.filename = filename
@@ -43,13 +44,13 @@ class PresignedUrl:
    def __init__(
        self,
        write,  # type: str
        access,  # type: str
        type,  # type: str
        **kwargs
    ):
        self.write = write
        self.access = access
        self.type = type  # mime-type to send to S3.


class HfApi:
@@ -58,8 +59,8 @@ class HfApi:
    def login(
        self,
        username,  # type: str
        password,  # type: str
    ):
        # type: (...) -> str
        """
@@ -78,8 +79,7 @@ class HfApi:
        return d["token"]

    def whoami(
        self, token,  # type: str
    ):
        # type: (...) -> str
        """
@@ -106,11 +106,7 @@ class HfApi:
        Call HF API to get a presigned url to upload `filename` to S3.
        """
        path = "{}/api/presign".format(self.endpoint)
        r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
        r.raise_for_status()
        d = r.json()
        return PresignedUrl(**d)
@@ -126,16 +122,14 @@ class HfApi:
        urls = self.presign(token, filename=filename)
        # streaming upload:
        # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
        #
        # Even though we presign with the correct content-type,
        # the client still has to specify it when uploading the file.
        with open(filepath, "rb") as f:
            pf = TqdmProgressFileReader(f)
            data = f if pf.total_size > 0 else ""

            r = requests.put(urls.write, data=data, headers={"content-type": urls.type,})
            r.raise_for_status()
            pf.close()
        return urls.access
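A sketch of the client flow around these methods; the credentials are placeholders, and the name of the upload method is an assumption since it is not fully visible in this excerpt:

    # Sketch: login, then upload a file through the presigned-URL flow shown above.
    from transformers.hf_api import HfApi

    api = HfApi()
    token = api.login(username="my-user", password="my-pass")  # placeholders
    # Assumed method name for the presign-then-PUT flow:
    # access_url = api.presign_and_upload(token, filename="weights.bin", filepath="./weights.bin")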
@@ -152,7 +146,6 @@ class HfApi:
        return [S3Obj(**x) for x in d]


class TqdmProgressFileReader:
    """
    Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
@@ -161,12 +154,12 @@ class TqdmProgressFileReader:
    see github.com/huggingface/transformers/pull/2078#discussion_r354739608
    for implementation details.
    """

    def __init__(
        self, f  # type: io.BufferedReader
    ):
        self.f = f
        self.total_size = os.fstat(f.fileno()).st_size  # type: int
        self.pbar = tqdm(total=self.total_size, leave=False)
        if six.PY3:
            # does not work unless PY3
@@ -182,7 +175,6 @@ class TqdmProgressFileReader:
        self.pbar.close()


class HfFolder:
    path_token = expanduser("~/.huggingface/token")
@@ -201,7 +193,7 @@ class HfFolder:
            if e.errno != os.errno.EEXIST:
                raise e
            pass
        with open(cls.path_token, "w+") as f:
            f.write(token)

    @classmethod
@@ -210,7 +202,7 @@ class HfFolder:
        Get token or None if not existent.
        """
        try:
            with open(cls.path_token, "r") as f:
                return f.read()
        except:
            # this is too wide. When Py2 is dead use:
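A usage sketch for the token helper; `save_token` is assumed as the name of the writing classmethod whose body appears just above:

    # Sketch: persist and read back the CLI token at ~/.huggingface/token.
    from transformers.hf_api import HfFolder

    HfFolder.save_token("dummy-token")  # illustrative value
    print(HfFolder.get_token())         # "dummy-token", or None on failure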
...
@@ -14,8 +14,7 @@
# limitations under the License.
""" Configuration base class and utilities."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
@@ -25,8 +24,15 @@ from io import open
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import (
    CONFIG_NAME,
    MODEL_CARD_NAME,
    WEIGHTS_NAME,
    TF2_WEIGHTS_NAME,
    cached_path,
    is_remote_url,
    hf_bucket_url,
)

logger = logging.getLogger(__name__)
@@ -48,17 +54,18 @@ class ModelCard(object):
        Parameters:
    """

    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes
        for key, value in kwargs.items():
modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False) modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
""" """
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop("proxies", None)
find_from_standard_name = kwargs.pop('find_from_standard_name', True) find_from_standard_name = kwargs.pop("find_from_standard_name", True)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
# For simplicity we use the same pretrained url than the configuration files # For simplicity we use the same pretrained url than the configuration files
@@ -145,36 +152,43 @@ class ModelCard(object):
        try:
            # Load from URL or cache if already cached
            resolved_model_card_file = cached_path(
                model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
            )
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info(
                    "loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
                )
            # Load model card
            modelcard = cls.from_json_file(resolved_model_card_file)

        except EnvironmentError:
            if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.warning("Couldn't reach server at '{}' to download model card file.".format(model_card_file))
            else:
                logger.warning(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url to a model card file named {} or "
                    "a directory containing such a file but couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ", ".join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                        model_card_file,
                        MODEL_CARD_NAME,
                    )
                )
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
            modelcard = cls()

        except json.JSONDecodeError:
            logger.warning(
                "Couldn't reach server at '{}' to download model card file or "
                "model card file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file)
            )
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
@@ -203,7 +217,7 @@ class ModelCard(object):
    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)
@@ -225,5 +239,5 @@ class ModelCard(object):
    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
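The two JSON helpers above are symmetric; a round-trip sketch under the same module-path assumption as before:

    # Sketch: serialize a ModelCard to disk and load it back.
    from transformers.modelcard import ModelCard

    card = ModelCard(model_details={"name": "demo"})
    card.to_json_file("/tmp/modelcard.json")
    print(ModelCard.from_json_file("/tmp/modelcard.json").model_details)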
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
@@ -30,14 +29,14 @@ logger = logging.getLogger(__name__)
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
}
@@ -48,8 +47,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
@@ -65,7 +66,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name
@@ -75,10 +76,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
@@ -97,19 +98,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            continue

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if "adam_m" in name or "adam_v" in name or "global_step" in name:
@@ -118,19 +119,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                l = re.split(r"_(\d+)", m_name)
            else:
                l = [m_name]

            if l[0] == "kernel" or l[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif l[0] == "output_bias" or l[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif l[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif l[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, l[0])
@@ -141,9 +142,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
                num = int(l[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
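The effect of the renaming chain is easiest to see on a concrete variable path. A sketch applying the rules that fire on one illustrative (not real) checkpoint name, in the same order as above:

    # Sketch: trace one TF variable name through the renaming rules above.
    name = "bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel"
    name = name.replace("ffn_1", "ffn")
    name = name.replace("bert/", "albert/")
    name = name.replace("transformer/", "")
    name = name.replace("inner_group_", "albert_layers/")
    name = name.replace("group_", "albert_layer_groups/")
    print(name)
    # albert/encoder/albert_layer_groups/0/albert_layers/0/ffn/intermediate/dense/kernel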
@@ -160,6 +161,7 @@ class AlbertEmbeddings(BertEmbeddings):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super(AlbertEmbeddings, self).__init__(config)
@@ -175,7 +177,7 @@ class AlbertAttention(BertSelfAttention):
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -237,10 +239,13 @@ class AlbertAttention(BertSelfAttention):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        reshaped_context_layer = context_layer.view(*new_context_layer_shape)

        # Should find a better way to do this
        w = (
            self.dense.weight.t()
            .view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
            .to(context_layer.dtype)
        )
        b = self.dense.bias.to(context_layer.dtype)

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
@@ -252,11 +257,11 @@ class AlbertAttention(BertSelfAttention):
class AlbertLayer(nn.Module):
    def __init__(self, config):
        super(AlbertLayer, self).__init__()

        self.config = config
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
@@ -273,7 +278,7 @@ class AlbertLayer(nn.Module):
class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super(AlbertLayerGroup, self).__init__()

        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
@@ -303,7 +308,7 @@ class AlbertLayerGroup(nn.Module):
class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super(AlbertTransformer, self).__init__()

        self.config = config
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
@@ -327,8 +332,12 @@ class AlbertTransformer(nn.Module):
            # Index of the layer inside the group
            layer_idx = int(i - group_idx * layers_per_group)

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
            )
            hidden_states = layer_group_output[0]

            if self.output_attentions:
@@ -337,7 +346,6 @@ class AlbertTransformer(nn.Module):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
@@ -346,11 +354,11 @@ class AlbertTransformer(nn.Module):
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
class AlbertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = AlbertConfig
    pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "albert"
@@ -431,8 +439,12 @@ ALBERT_INPUTS_DOCSTRING = r"""
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""


@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
    ALBERT_INPUTS_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -500,8 +512,15 @@ class AlbertModel(AlbertPreTrainedModel):
                inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
                self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -520,31 +539,37 @@ class AlbertModel(AlbertPreTrainedModel):
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))

        outputs = (sequence_output, pooled_output) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs
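A forward-pass sketch for the bare model, using a checkpoint name from the archive map above (the tokenizer class is assumed to be exported alongside the model):

    # Sketch: encode a sentence with the bare ALBERT encoder.
    import torch
    from transformers import AlbertModel, AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertModel.from_pretrained("albert-base-v2")
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, ALBERT!", add_special_tokens=True)])
    with torch.no_grad():
        sequence_output, pooled_output = model(input_ids)[:2]
    print(sequence_output.shape, pooled_output.shape)  # (1, seq_len, hidden), (1, hidden)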
class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super(AlbertMLMHead, self).__init__()
@@ -566,7 +591,9 @@ class AlbertMLMHead(nn.Module):
        return prediction_scores


@add_start_docstrings(
    "Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
    r"""
    **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -602,21 +629,28 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)

    def get_output_embeddings(self):
        return self.predictions.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
    ):
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_outputs = outputs[0]
return outputs return outputs
@add_start_docstrings("""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings(
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """, the pooled output) e.g. for GLUE tasks. """,
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING) ALBERT_START_DOCSTRING,
ALBERT_INPUTS_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel): class AlbertForSequenceClassification(AlbertPreTrainedModel):
r""" r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -665,6 +702,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):

        loss, logits = outputs[:2]

    """

    def __init__(self, config):
        super(AlbertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

@@ -675,8 +713,16 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.albert(
            input_ids=input_ids,

@@ -684,7 +730,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):

            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

@@ -707,10 +753,12 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
        return outputs  # (loss), logits, (hidden_states), (attentions)
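A minimal classification sketch, again assuming the `albert-base-v1` checkpoint; `labels` has shape `(batch_size,)`:

    import torch
    from transformers import AlbertTokenizer, AlbertForSequenceClassification

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v1")
    model = AlbertForSequenceClassification.from_pretrained("albert-base-v1")

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
    labels = torch.tensor([1])  # one label per sequence in the batch
    loss, logits = model(input_ids, labels=labels)[:2]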
@add_start_docstrings(
    """Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    ALBERT_START_DOCSTRING,
    ALBERT_INPUTS_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -752,6 +800,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    """

    def __init__(self, config):
        super(AlbertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

@@ -761,8 +810,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.albert(
            input_ids=input_ids,

@@ -770,7 +828,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):

            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]
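A QA usage sketch under the same `albert-base-v1` assumption; the predicted span is read off the argmax of the start and end logits:

    import torch
    from transformers import AlbertTokenizer, AlbertForQuestionAnswering

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v1")
    model = AlbertForQuestionAnswering.from_pretrained("albert-base-v1")

    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    input_ids = tokenizer.encode(question, text)
    start_scores, end_scores = model(torch.tensor([input_ids]))[:2]
    answer_ids = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]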
...
@@ -18,31 +18,87 @@ from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from .configuration_auto import (
    AlbertConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    TransfoXLConfig,
    XLMConfig,
    XLNetConfig,
    XLMRobertaConfig,
)

from .modeling_bert import (
    BertModel,
    BertForMaskedLM,
    BertForSequenceClassification,
    BertForQuestionAnswering,
    BertForTokenClassification,
    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlnet import (
    XLNetModel,
    XLNetLMHeadModel,
    XLNetForSequenceClassification,
    XLNetForQuestionAnswering,
    XLNetForTokenClassification,
    XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_xlm import (
    XLMModel,
    XLMWithLMHeadModel,
    XLMForSequenceClassification,
    XLMForQuestionAnswering,
    XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_roberta import (
    RobertaModel,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_distilbert import (
    DistilBertModel,
    DistilBertForQuestionAnswering,
    DistilBertForMaskedLM,
    DistilBertForSequenceClassification,
    DistilBertForTokenClassification,
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_camembert import (
    CamembertModel,
    CamembertForMaskedLM,
    CamembertForSequenceClassification,
    CamembertForMultipleChoice,
    CamembertForTokenClassification,
    CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_albert import (
    AlbertModel,
    AlbertForMaskedLM,
    AlbertForSequenceClassification,
    AlbertForQuestionAnswering,
    ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlm_roberta import (
    XLMRobertaModel,
    XLMRobertaForMaskedLM,
    XLMRobertaForSequenceClassification,
    XLMRobertaForMultipleChoice,
    XLMRobertaForTokenClassification,
    XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -51,7 +107,8 @@ from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,

@@ -66,8 +123,9 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)

        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
)
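The generator expression above simply flattens the per-model archive maps into one name-to-URL dict; a toy equivalent with made-up data:

    maps = [{"a": 1}, {"b": 2, "c": 3}]
    merged = dict((key, value) for m in maps for key, value in m.items())
    assert merged == {"a": 1, "b": 2, "c": 3}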
class AutoModel(object):

@@ -98,10 +156,13 @@ class AutoModel(object):

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )
    @classmethod
    def from_config(cls, config):

@@ -232,35 +293,39 @@ class AutoModel(object):

            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if "t5" in pretrained_model_name_or_path:
            return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta', 'ctrl', 'distilbert', 'camembert', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
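Note that the dispatch above (and in every `AutoModel*` class below) is a plain substring match, so the order of the branches matters: "xlm-roberta" must be tested before "roberta", and "roberta" before "bert", or the more generic pattern would capture the name first. A tiny check illustrating why:

    name = "roberta-base"
    assert "bert" in name      # the generic pattern also fires for RoBERTa checkpoints...
    assert "roberta" in name   # ...so the more specific patterns must be tested first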
class AutoModelWithLMHead(object):

@@ -291,10 +356,13 @@ class AutoModelWithLMHead(object):

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelWithLMHead.from_config(config)` methods."
        )
    @classmethod
    def from_config(cls, config):

@@ -423,35 +491,39 @@ class AutoModelWithLMHead(object):

            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if "t5" in pretrained_model_name_or_path:
            return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta', 'ctrl', 'distilbert', 'camembert', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
class AutoModelForSequenceClassification(object):

@@ -477,10 +549,13 @@ class AutoModelForSequenceClassification(object):

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )
    @classmethod
    def from_config(cls, config):

@@ -597,25 +672,39 @@ class AutoModelForSequenceClassification(object):

            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if "distilbert" in pretrained_model_name_or_path:
            return DistilBertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "albert" in pretrained_model_name_or_path:
            return AlbertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "bert" in pretrained_model_name_or_path:
            return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
class AutoModelForQuestionAnswering(object):

@@ -638,10 +727,13 @@ class AutoModelForQuestionAnswering(object):

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )
    @classmethod
    def from_config(cls, config):

@@ -745,26 +837,30 @@ class AutoModelForQuestionAnswering(object):

            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if "distilbert" in pretrained_model_name_or_path:
            return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
        )
class AutoModelForTokenClassification:
    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )
    @classmethod
    def from_config(cls, config):

@@ -797,7 +893,7 @@ class AutoModelForTokenClassification:

        elif isinstance(config, XLMRobertaConfig):
            return XLMRobertaForTokenClassification(config)
        raise ValueError("Unrecognized configuration class {}".format(config))
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the token classification model classes of the library

@@ -870,18 +966,28 @@ class AutoModelForTokenClassification:

            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if "camembert" in pretrained_model_name_or_path:
            return CamembertForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaForTokenClassification.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
                pretrained_model_name_or_path
            )
        )
...

@@ -33,27 +33,27 @@ from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
}
@@ -65,8 +65,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
@@ -81,7 +83,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):

        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
@@ -89,18 +91,18 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):

            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                l = re.split(r"_(\d+)", m_name)
            else:
                l = [m_name]
            if l[0] == "kernel" or l[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif l[0] == "output_bias" or l[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif l[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif l[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, l[0])

@@ -110,9 +112,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):

            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
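A toy illustration of the traversal above, using a hypothetical variable name rather than a real checkpoint: each slash-separated scope becomes an attribute lookup, numbered scopes like `layer_3` become an indexed lookup, and `kernel` maps to a (transposed) `weight`:

    import re

    name = "bert/encoder/layer_3/attention/output/dense/kernel"
    for m_name in name.split("/"):
        if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
            scope, index = re.split(r"_(\d+)", m_name)[:2]
            print(scope, int(index))  # e.g. layer 3 -> pointer = pointer.layer[3]
        elif m_name == "kernel":
            print("kernel -> 'weight' attribute; the array is transposed")
        else:
            print(m_name, "-> getattr(pointer, m_name)")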
@@ -157,6 +159,7 @@ BertLayerNorm = torch.nn.LayerNorm

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
@@ -199,7 +202,8 @@ class BertSelfAttention(nn.Module):

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
@@ -217,7 +221,14 @@ class BertSelfAttention(nn.Module):

        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
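The `view`/`permute` pair above reshapes (batch, seq_len, hidden) into (batch, num_heads, seq_len, head_size) so each head attends independently; a shape-only sketch with toy dimensions:

    import torch

    batch, seq_len, num_heads, head_size = 2, 5, 12, 64
    x = torch.randn(batch, seq_len, num_heads * head_size)  # (2, 5, 768)
    x = x.view(batch, seq_len, num_heads, head_size)        # (2, 5, 12, 64)
    x = x.permute(0, 2, 1, 3)                               # (2, 12, 5, 64)
    assert x.shape == (batch, num_heads, seq_len, head_size)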
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
@@ -307,8 +318,17 @@ class BertAttention(nn.Module):

        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
@@ -353,13 +373,22 @@ class BertLayer(nn.Module):

        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
@@ -376,14 +405,23 @@ class BertEncoder(nn.Module):

        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
# The output weights are the same as the input embeddings, but there is # The output weights are the same as the input embeddings, but there is
# an output-only bias for each token. # an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
config.vocab_size,
bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) self.bias = nn.Parameter(torch.zeros(config.vocab_size))
@@ -488,6 +524,7 @@ class BertPreTrainedModel(PreTrainedModel):

    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
@@ -581,8 +618,12 @@ BERT_INPUTS_DOCSTRING = r"""

            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) @add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertModel(BertPreTrainedModel): class BertModel(BertPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -612,6 +653,7 @@ class BertModel(BertPreTrainedModel):

        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.config = config

@@ -636,8 +678,17 @@ class BertModel(BertPreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """ Forward pass on the Model.

        The model can behave as an encoder (with only self-attention) as well

@@ -681,12 +732,18 @@ class BertModel(BertPreTrainedModel):
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                causal_mask = causal_mask.to(
                    torch.long
                )  # not converting to long will cause errors with pytorch version < 1.3
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )
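The broadcast comparison above builds a lower-triangular (causal) mask: position i may only attend to positions 0..i. A tiny worked example:

    import torch

    seq_ids = torch.arange(4)
    causal = seq_ids[None, None, :].repeat(1, 4, 1) <= seq_ids[None, :, None]
    print(causal[0].long())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]])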
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for

@@ -709,10 +766,15 @@ class BertModel(BertPreTrainedModel):

            elif encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
                        encoder_hidden_shape, encoder_attention_mask.shape
                    )
                )

            encoder_extended_attention_mask = encoder_extended_attention_mask.to(
                dtype=next(self.parameters()).dtype
            )  # fp16 compatibility
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None
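The `(1.0 - mask) * -10000.0` step converts the 1/0 attention mask into an additive bias of 0/-10000 applied to the attention logits before the softmax, so masked positions get effectively zero probability; a one-line check:

    import torch

    attention_mask = torch.tensor([1.0, 1.0, 0.0])  # 1 = attend, 0 = masked
    additive = (1.0 - attention_mask) * -10000.0
    print(additive)  # tensor([-0., -0., -10000.]) -- masked logits are pushed toward -inf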
@@ -727,28 +789,40 @@ class BertModel(BertPreTrainedModel):

            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = (
                head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds) embedding_output = self.embeddings(
encoder_outputs = self.encoder(embedding_output, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
attention_mask=extended_attention_mask, )
head_mask=head_mask, encoder_outputs = self.encoder(
encoder_hidden_states=encoder_hidden_states, embedding_output,
encoder_attention_mask=encoder_extended_attention_mask) attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
)
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions) return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
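For reference, calling the reformatted BertModel end to end. A minimal sketch assuming the usual public bert-base-uncased checkpoint (not something this diff pins):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
# The first two outputs are exactly the ones returned above.
sequence_output, pooled_output = model(input_ids)[:2]
print(sequence_output.shape)  # (1, sequence_length, hidden_size)
print(pooled_output.shape)    # (1, hidden_size)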
@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -786,6 +860,7 @@ class BertForPreTraining(BertPreTrainedModel):
        prediction_scores, seq_relationship_scores = outputs[:2]

    """

    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
@@ -797,20 +872,33 @@ class BertForPreTraining(BertPreTrainedModel):
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        next_sentence_label=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
@@ -822,9 +910,9 @@ class BertForPreTraining(BertPreTrainedModel):
        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
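Usage is unchanged by the reformat. A minimal sketch of the two pre-training heads (checkpoint name and input are illustrative):

import torch
from transformers import BertForPreTraining, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForPreTraining.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
outputs = model(input_ids)
# MLM logits over the vocabulary, and NSP logits over {is_next, not_next}.
prediction_scores, seq_relationship_scores = outputs[:2]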
@add_start_docstrings(
    """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
)
class BertForMaskedLM(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -862,6 +950,7 @@ class BertForMaskedLM(BertPreTrainedModel):
        loss, prediction_scores = outputs[:2]

    """

    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
@@ -873,17 +962,30 @@ class BertForMaskedLM(BertPreTrainedModel):
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)
@@ -912,9 +1014,11 @@ class BertForMaskedLM(BertPreTrainedModel):
        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
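A minimal masked-LM sketch: passing masked_lm_labels makes the first output the loss (checkpoint name is illustrative; positions labelled -1 are ignored by the loss):

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
# Score the inputs against themselves, just to exercise the loss path.
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]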
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
    r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -945,6 +1049,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        seq_relationship_scores = outputs[0]

    """

    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
@@ -953,15 +1058,25 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        pooled_output = outputs[1]
@@ -976,10 +1091,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
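A minimal NSP sketch, assuming the public bert-base-uncased checkpoint; encode() with a text pair inserts the [CLS]/[SEP] special tokens:

import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("How old are you?", "The Eiffel Tower is in Paris")])
seq_relationship_scores = model(input_ids)[0]
print(seq_relationship_scores.shape)  # (1, 2): logits for [is_next, not_next]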
@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -1011,6 +1128,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
        loss, logits = outputs[:2]

    """

    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
@@ -1021,15 +1139,25 @@ class BertForSequenceClassification(BertPreTrainedModel):
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        pooled_output = outputs[1]
@@ -1051,10 +1179,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
        return outputs  # (loss), logits, (hidden_states), (attentions)
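A minimal classification sketch; the label value is hypothetical (config.num_labels defaults to 2 here):

import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
labels = torch.tensor([1])  # hypothetical class index
loss, logits = model(input_ids, labels=labels)[:2]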
@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -1087,6 +1217,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
        loss, classification_scores = outputs[:2]

    """

    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)
@@ -1096,8 +1227,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        num_choices = input_ids.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1105,12 +1244,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        pooled_output = outputs[1]
@@ -1128,10 +1269,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
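A minimal multiple-choice sketch: input_ids must be (batch, num_choices, seq_len), which the view() calls above flatten before the shared encoder. The two choices are illustrative and happen to tokenize to the same length (no padding shown):

import torch
from transformers import BertForMultipleChoice, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMultipleChoice.from_pretrained("bert-base-uncased")

choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(c) for c in choices]).unsqueeze(0)  # (1, 2, seq_len)
labels = torch.tensor([1])  # hypothetical gold choice
loss, classification_scores = model(input_ids, labels=labels)[:2]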
@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
@@ -1161,6 +1304,7 @@ class BertForTokenClassification(BertPreTrainedModel):
        loss, scores = outputs[:2]

    """

    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels
@@ -1171,15 +1315,25 @@ class BertForTokenClassification(BertPreTrainedModel):
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
@@ -1202,10 +1356,12 @@ class BertForTokenClassification(BertPreTrainedModel):
        return outputs  # (loss), scores, (hidden_states), (attentions)
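A minimal token-classification sketch: labels are per token, shape (batch, seq_len), with hypothetical tag indices below:

import torch
from transformers import BertForTokenClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForTokenClassification.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
labels = torch.tensor([[1] * input_ids.size(1)])  # hypothetical per-token tags
loss, scores = model(input_ids, labels=labels)[:2]
print(scores.shape)  # (1, seq_len, num_labels)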
@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
@@ -1247,6 +1403,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
    """

    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels
@@ -1256,15 +1413,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]
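A minimal extractive-QA sketch: decoding the span from the start/end logits. The base checkpoint's QA head is untrained, so the extracted span is meaningless unless you load a SQuAD-fine-tuned checkpoint instead:

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_ids = torch.tensor([tokenizer.encode(question, text)])
start_scores, end_scores = model(input_ids)[:2]

# Greedy span decoding: argmax of the start and end logits.
start = start_scores.argmax().item()
end = end_scores.argmax().item()
answer_ids = input_ids[0, start : end + 1]
print(tokenizer.convert_ids_to_tokens(answer_ids.tolist()))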