Commit fa84ae26 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
parent 63e3827c
......@@ -24,82 +24,270 @@ import tensorflow as tf
from transformers import is_torch_available, cached_path
from transformers import (load_pytorch_checkpoint_in_tf2_model,
BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
from transformers import (
load_pytorch_checkpoint_in_tf2_model,
BertConfig,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2Config,
TFGPT2LMHeadModel,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNetConfig,
TFXLNetLMHeadModel,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMConfig,
TFXLMWithLMHeadModel,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig,
TFCTRLLMHeadModel,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig,
TFAlbertForMaskedLM,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5Config,
TFT5WithLMHeadModel,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
)
if is_torch_available():
import torch
import numpy as np
from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
from transformers import (
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
)
else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None,
None, None,
None, None,
None, None,
None, None,
None, None,
None, None, None,
None, None, None, None,
None, None,
None, None,
None, None)
(
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
) = (
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
import logging
logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = {
'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
"bert": (
BertConfig,
TFBertForPreTraining,
BertForPreTraining,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-large-uncased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-large-cased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-base-cased-finetuned-mrpc": (
BertConfig,
TFBertForSequenceClassification,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlnet": (
XLNetConfig,
TFXLNetLMHeadModel,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm": (
XLMConfig,
TFXLMWithLMHeadModel,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"transfo-xl": (
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"openai-gpt": (
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (
RobertaConfig,
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert": (
DistilBertConfig,
TFDistilBertForMaskedLM,
DistilBertForMaskedLM,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-uncased-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-uncased-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (
CTRLConfig,
TFCTRLLMHeadModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"albert": (
AlbertConfig,
TFAlbertForMaskedLM,
AlbertForMaskedLM,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"t5": (
T5Config,
TFT5WithLMHeadModel,
T5WithLMHeadModel,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
def convert_pt_checkpoint_to_tf(
model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
......@@ -116,17 +304,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
# Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_model_maps:
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
pytorch_checkpoint_path = cached_path(
aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
)
# Load PyTorch checkpoint in tf2 model:
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
if compare_with_pt_model:
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
config=config,
state_dict=state_dict)
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
pt_model = pt_model_class.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
with torch.no_grad():
pto = pt_model(**pt_model.dummy_inputs)
......@@ -139,11 +329,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
# Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path))
tf_model.save_weights(tf_dump_path, save_format='h5')
tf_model.save_weights(tf_dump_path, save_format="h5")
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
def convert_all_pt_checkpoints_to_tf(
args_model_type,
tf_dump_path,
model_shortcut_names_or_path=None,
config_shortcut_names_or_path=None,
compare_with_pt_model=False,
use_cached_models=False,
remove_cached_files=False,
only_convert_finetuned_models=False,
):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None:
......@@ -156,7 +354,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
print("=" * 100)
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
raise ValueError(
"Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
)
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
......@@ -166,9 +366,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
config_shortcut_names_or_path = model_shortcut_names_or_path
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
):
print("-" * 100)
if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
if not only_convert_finetuned_models:
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
continue
......@@ -176,7 +377,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
elif only_convert_finetuned_models:
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
continue
print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
print(
" Converting checkpoint {}/{}: {} - model_type {}".format(
i, len(aws_config_map), model_shortcut_name, model_type
)
)
print("-" * 100)
if config_shortcut_name in aws_config_map:
......@@ -190,13 +395,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
if os.path.isfile(model_shortcut_name):
model_shortcut_name = 'converted_model'
model_shortcut_name = "converted_model"
convert_pt_checkpoint_to_tf(model_type=model_type,
pytorch_checkpoint_path=model_file,
config_file=config_file,
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
compare_with_pt_model=compare_with_pt_model)
convert_pt_checkpoint_to_tf(
model_type=model_type,
pytorch_checkpoint_path=model_file,
config_file=config_file,
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
compare_with_pt_model=compare_with_pt_model,
)
if remove_cached_files:
os.remove(config_file)
os.remove(model_file)
......@@ -205,39 +412,47 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output Tensorflow dump file.")
parser.add_argument("--model_type",
default = None,
type = str,
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
parser.add_argument("--pytorch_checkpoint_path",
default = None,
type = str,
help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
"If not given, will download and convert all the checkpoints from AWS.")
parser.add_argument("--config_file",
default = None,
type = str,
help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture. If not given and "
"--pytorch_checkpoint_path is not given or is a shortcut name"
"use the configuration associated to the shortcut name on the AWS")
parser.add_argument("--compare_with_pt_model",
action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.")
parser.add_argument("--use_cached_models",
action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
parser.add_argument("--remove_cached_files",
action='store_true',
help = "Remove pytorch models after conversion (save memory when converting in batches).")
parser.add_argument("--only_convert_finetuned_models",
action='store_true',
help = "Only convert finetuned models.")
parser.add_argument(
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
)
parser.add_argument(
"--model_type",
default=None,
type=str,
help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
list(MODEL_CLASSES.keys())
),
)
parser.add_argument(
"--pytorch_checkpoint_path",
default=None,
type=str,
help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
"If not given, will download and convert all the checkpoints from AWS.",
)
parser.add_argument(
"--config_file",
default=None,
type=str,
help="The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture. If not given and "
"--pytorch_checkpoint_path is not given or is a shortcut name"
"use the configuration associated to the shortcut name on the AWS",
)
parser.add_argument(
"--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
)
parser.add_argument(
"--use_cached_models",
action="store_true",
help="Use cached models if possible instead of updating to latest checkpoint versions.",
)
parser.add_argument(
"--remove_cached_files",
action="store_true",
help="Remove pytorch models after conversion (save memory when converting in batches).",
)
parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
args = parser.parse_args()
# if args.pytorch_checkpoint_path is not None:
......@@ -248,11 +463,15 @@ if __name__ == "__main__":
# compare_with_pt_model=args.compare_with_pt_model,
# use_cached_models=args.use_cached_models)
# else:
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path,
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models,
remove_cached_files=args.remove_cached_files,
only_convert_finetuned_models=args.only_convert_finetuned_models)
convert_all_pt_checkpoints_to_tf(
args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path,
model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
if args.pytorch_checkpoint_path is not None
else None,
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models,
remove_cached_files=args.remove_cached_files,
only_convert_finetuned_models=args.only_convert_finetuned_models,
)
......@@ -30,20 +30,27 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from transformers.modeling_bert import (BertConfig, BertEncoder,
BertIntermediate, BertLayer,
BertModel, BertOutput,
BertSelfAttention,
BertSelfOutput)
from transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel)
from transformers.modeling_bert import (
BertConfig,
BertEncoder,
BertIntermediate,
BertLayer,
BertModel,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
from transformers.modeling_roberta import (
RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
SAMPLE_TEXT = "Hello world! cécé herlolip"
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
......@@ -61,7 +68,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
intermediate_size=roberta.args.encoder_ffn_embed_dim,
max_position_embeddings=514,
type_vocab_size=1,
layer_norm_eps=1e-5, # PyTorch default used in fairseq
layer_norm_eps=1e-5, # PyTorch default used in fairseq
)
if classification_head:
config.num_labels = roberta.args.num_classes
......@@ -74,7 +81,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
# Embeddings
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
model.roberta.embeddings.token_type_embeddings.weight
) # just zero them out b/c RoBERTa doesn't use them.
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
......@@ -85,11 +94,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
### self attention
self_attn: BertSelfAttention = layer.attention.self
assert(
roberta_layer.self_attn.k_proj.weight.data.shape == \
roberta_layer.self_attn.q_proj.weight.data.shape == \
roberta_layer.self_attn.v_proj.weight.data.shape == \
torch.Size((config.hidden_size, config.hidden_size))
assert (
roberta_layer.self_attn.k_proj.weight.data.shape
== roberta_layer.self_attn.q_proj.weight.data.shape
== roberta_layer.self_attn.v_proj.weight.data.shape
== torch.Size((config.hidden_size, config.hidden_size))
)
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
......@@ -101,9 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
### self-attention output
self_output: BertSelfOutput = layer.attention.output
assert(
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
)
assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
......@@ -111,28 +118,24 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
### intermediate
intermediate: BertIntermediate = layer.intermediate
assert(
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
)
assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
intermediate.dense.weight = roberta_layer.fc1.weight
intermediate.dense.bias = roberta_layer.fc1.bias
### output
bert_output: BertOutput = layer.output
assert(
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
)
assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
bert_output.dense.weight = roberta_layer.fc2.weight
bert_output.dense.bias = roberta_layer.fc2.bias
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
#### end of layer
if classification_head:
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
else:
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
......@@ -143,21 +146,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
model.lm_head.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0]
if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape)
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
success = torch.allclose(our_output, their_output, atol=1e-3)
print(
"Do both models output the same tensors?",
"🔥" if success else "💩"
)
print("Do both models output the same tensors?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
......@@ -169,23 +169,16 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--roberta_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--classification_head",
action = "store_true",
help = "Whether to convert a final classification head.")
parser.add_argument(
"--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--classification_head", action="store_true", help="Whether to convert a final classification head."
)
args = parser.parse_args()
convert_roberta_checkpoint_to_pytorch(
args.roberta_checkpoint_path,
args.pytorch_dump_folder_path,
args.classification_head
args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
)
......@@ -24,8 +24,10 @@ import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5
import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
......@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained T5 model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained T5 model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.config_file,
args.pytorch_dump_path)
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
......@@ -26,9 +26,8 @@ import torch
import transformers.tokenization_transfo_xl as data_utils
from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -36,32 +35,33 @@ else:
import pickle
import logging
logging.basicConfig(level=logging.INFO)
# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules['data_utils'] = data_utils
sys.modules['vocabulary'] = data_utils
sys.modules["data_utils"] = data_utils
sys.modules["vocabulary"] = data_utils
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
transfo_xl_config_file,
pytorch_dump_folder_path,
transfo_xl_dataset_file):
def convert_transfo_xl_checkpoint_to_pytorch(
tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
):
if transfo_xl_dataset_file:
# Convert a pre-processed corpus (see original TensorFlow repo)
with open(transfo_xl_dataset_file, "rb") as fp:
corpus = pickle.load(fp, encoding="latin1")
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
corpus_vocab_dict = corpus.vocab.__dict__
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
corpus_dict_no_vocab = corpus.__dict__
corpus_dict_no_vocab.pop('vocab', None)
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
corpus_dict_no_vocab.pop("vocab", None)
pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
print("Save dataset to {}".format(pytorch_dataset_dump_path))
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
......@@ -92,26 +92,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
parser.add_argument("--tf_checkpoint_path",
default = "",
type = str,
help = "An optional path to a TensorFlow checkpoint path to be converted.")
parser.add_argument("--transfo_xl_config_file",
default = "",
type = str,
help = "An optional config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--transfo_xl_dataset_file",
default = "",
type = str,
help = "An optional dataset file to be converted in a vocabulary.")
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the folder to store the PyTorch model or dataset/vocab.",
)
parser.add_argument(
"--tf_checkpoint_path",
default="",
type=str,
help="An optional path to a TensorFlow checkpoint path to be converted.",
)
parser.add_argument(
"--transfo_xl_config_file",
default="",
type=str,
help="An optional config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--transfo_xl_dataset_file",
default="",
type=str,
help="An optional dataset file to be converted in a vocabulary.",
)
args = parser.parse_args()
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.transfo_xl_config_file,
args.pytorch_dump_folder_path,
args.transfo_xl_dataset_file)
convert_transfo_xl_checkpoint_to_pytorch(
args.tf_checkpoint_path,
args.transfo_xl_config_file,
args.pytorch_dump_folder_path,
args.transfo_xl_dataset_file,
)
......@@ -27,32 +27,34 @@ from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
import logging
logging.basicConfig(level=logging.INFO)
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
# Load checkpoint
chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")
state_dict = chkpt['model']
state_dict = chkpt["model"]
# We have the base model one level deeper than the original XLM repository
two_levels_state_dict = {}
for k, v in state_dict.items():
if 'pred_layer' in k:
if "pred_layer" in k:
two_levels_state_dict[k] = v
else:
two_levels_state_dict['transformer.' + k] = v
two_levels_state_dict["transformer." + k] = v
config = chkpt['params']
config = chkpt["params"]
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
vocab = chkpt['dico_word2id']
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
vocab = chkpt["dico_word2id"]
vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(two_levels_state_dict, pytorch_weights_dump_path)
......@@ -69,15 +71,11 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--xlm_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument(
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
......@@ -22,11 +22,15 @@ import os
import argparse
import torch
from transformers import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig,
XLNetLMHeadModel, XLNetForQuestionAnswering,
XLNetForSequenceClassification,
load_tf_weights_in_xlnet)
from transformers import (
CONFIG_NAME,
WEIGHTS_NAME,
XLNetConfig,
XLNetLMHeadModel,
XLNetForQuestionAnswering,
XLNetForSequenceClassification,
load_tf_weights_in_xlnet,
)
GLUE_TASKS_NUM_LABELS = {
"cola": 2,
......@@ -41,9 +45,13 @@ GLUE_TASKS_NUM_LABELS = {
}
import logging
logging.basicConfig(level=logging.INFO)
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
def convert_xlnet_checkpoint_to_pytorch(
tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
):
# Initialise PyTorch model
config = XLNetConfig.from_json_file(bert_config_file)
......@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
config.finetuning_task = finetuning_task
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
model = XLNetForSequenceClassification(config)
elif 'squad' in finetuning_task:
elif "squad" in finetuning_task:
config.finetuning_task = finetuning_task
model = XLNetForQuestionAnswering(config)
else:
......@@ -75,30 +83,33 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--xlnet_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained XLNet model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
parser.add_argument("--finetuning_task",
default = None,
type = str,
help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--xlnet_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained XLNet model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the folder to store the PyTorch model or dataset/vocab.",
)
parser.add_argument(
"--finetuning_task",
default=None,
type=str,
help="Name of a task on which the XLNet TensorFloaw model was fine-tuned",
)
args = parser.parse_args()
print(args)
convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.xlnet_config_file,
args.pytorch_dump_folder_path,
args.finetuning_task)
convert_xlnet_checkpoint_to_pytorch(
args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
)
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
from .processors import (
InputExample,
InputFeatures,
DataProcessor,
SquadFeatures,
SingleSentenceClassificationProcessor,
)
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
from .metrics import is_sklearn_available
if is_sklearn_available():
from .metrics import glue_compute_metrics, xnli_compute_metrics
......@@ -23,20 +23,22 @@ logger = logging.getLogger(__name__)
try:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
_has_sklearn = True
except (AttributeError, ImportError) as e:
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
_has_sklearn = False
def is_sklearn_available():
return _has_sklearn
if _has_sklearn:
def simple_accuracy(preds, labels):
return (preds == labels).mean()
def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
......@@ -46,7 +48,6 @@ if _has_sklearn:
"acc_and_f1": (acc + f1) / 2,
}
def pearson_and_spearman(preds, labels):
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
......@@ -56,7 +57,6 @@ if _has_sklearn:
"corr": (pearson_corr + spearman_corr) / 2,
}
def glue_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "cola":
......@@ -82,7 +82,6 @@ if _has_sklearn:
else:
raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "xnli":
......
......@@ -24,19 +24,21 @@ logger = logging.getLogger(__name__)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return ' '.join(text.split())
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
......@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):
for example in examples:
qas_id = example.qas_id
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = ['']
gold_answers = [""]
if qas_id not in preds:
print('Missing prediction for %s' % qas_id)
print("Missing prediction for %s" % qas_id)
continue
prediction = preds[qas_id]
......@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores.values()) / total),
("f1", 100.0 * sum(f1_scores.values()) / total),
("total", total),
]
)
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
main_eval["%s_%s" % (prefix, k)] = new_eval[k]
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
......@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact
main_eval['best_exact_thresh'] = exact_thresh
main_eval['best_f1'] = best_f1
main_eval['best_f1_thresh'] = f1_thresh
main_eval['has_ans_exact'] = has_ans_exact
main_eval['has_ans_f1'] = has_ans_f1
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
main_eval["has_ans_exact"] = has_ans_exact
main_eval["has_ans_f1"] = has_ans_f1
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
......@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact
main_eval['best_exact_thresh'] = exact_thresh
main_eval['best_f1'] = best_f1
main_eval['best_f1_thresh'] = f1_thresh
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
......@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
exact, f1 = get_raw_scores(examples, preds)
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
exact_threshold = apply_no_ans_threshold(
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
)
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
evaluation = make_eval_dict(exact_threshold, f1_threshold)
if has_answer_qids:
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
merge_eval(evaluation, has_ans_eval, 'HasAns')
merge_eval(evaluation, has_ans_eval, "HasAns")
if no_answer_qids:
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
merge_eval(evaluation, no_ans_eval, 'NoAns')
merge_eval(evaluation, no_ans_eval, "NoAns")
if no_answer_probs:
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
......@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose_logging:
logger.info(
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
end_position = start_position + len(pred_text) - 1
......@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
......@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
logger.info("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text
......@@ -393,8 +397,8 @@ def compute_predictions_logits(
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
)
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
......@@ -447,7 +451,9 @@ def compute_predictions_logits(
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
end_logit=result.end_logits[end_index],
)
)
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
......@@ -455,14 +461,14 @@ def compute_predictions_logits(
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
end_logit=null_end_logit,
)
)
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
"NbestPrediction", ["text", "start_logit", "end_logit"]
)
seen_predictions = {}
nbest = []
......@@ -471,10 +477,10 @@ def compute_predictions_logits(
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
......@@ -498,31 +504,21 @@ def compute_predictions_logits(
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
# In very rare edge cases we could only have single null prediction.
# So we just create a nonce prediction in this case to avoid failure.
if len(nbest) == 1:
nbest.insert(0,
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
......@@ -551,8 +547,7 @@ def compute_predictions_logits(
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
......@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
end_n_top,
version_2_with_negative,
tokenizer,
verbose_logging
verbose_logging,
):
""" XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed.
......@@ -594,12 +589,12 @@ def compute_predictions_log_probs(
Requires utils_squad_evaluate.py
"""
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["feature_index", "start_index", "end_index",
"start_log_prob", "end_log_prob"])
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
)
logger.info("Writing predictions to: %s", output_prediction_file)
# logger.info("Writing nbest to: %s" % (output_nbest_file))
......@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
start_index=start_index,
end_index=end_index,
start_log_prob=start_log_prob,
end_log_prob=end_log_prob))
end_log_prob=end_log_prob,
)
)
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_log_prob + x.end_log_prob),
reverse=True)
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
)
seen_predictions = {}
nbest = []
......@@ -688,10 +684,10 @@ def compute_predictions_log_probs(
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
# Previously used Bert untokenizer
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
# Clean whitespace
......@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
else:
do_lower_case = tokenizer.do_lowercase_and_remove_accent
final_text = get_final_text(tok_text, orig_text, do_lower_case,
verbose_logging)
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
if final_text in seen_predictions:
continue
......@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_log_prob=pred.start_log_prob,
end_log_prob=pred.end_log_prob))
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
)
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="", start_log_prob=-1e6,
end_log_prob=-1e6))
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
total_scores = []
best_non_null_entry = None
......
from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
\ No newline at end of file
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
......@@ -27,15 +27,18 @@ if is_tf_available():
logger = logging.getLogger(__name__)
def glue_convert_examples_to_features(examples, tokenizer,
max_length=512,
task=None,
label_list=None,
output_mode=None,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
mask_padding_with_zero=True):
def glue_convert_examples_to_features(
examples,
tokenizer,
max_length=512,
task=None,
label_list=None,
output_mode=None,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
mask_padding_with_zero=True,
):
"""
Loads a data file into a list of ``InputFeatures``
......@@ -82,12 +85,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
example = processor.get_example_from_tensor_dict(example)
example = processor.tfds_map(example)
inputs = tokenizer.encode_plus(
example.text_a,
example.text_b,
add_special_tokens=True,
max_length=max_length,
)
inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
......@@ -106,8 +104,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
len(attention_mask), max_length
)
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
len(token_type_ids), max_length
)
if output_mode == "classification":
label = label_map[example.label]
......@@ -125,28 +127,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
logger.info("label: %s (id = %d)" % (example.label, label))
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
label=label))
InputFeatures(
input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
)
)
if is_tf_available() and is_tf_dataset:
def gen():
for ex in features:
yield ({'input_ids': ex.input_ids,
'attention_mask': ex.attention_mask,
'token_type_ids': ex.token_type_ids},
ex.label)
return tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32,
'token_type_ids': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None]),
'token_type_ids': tf.TensorShape([None])},
tf.TensorShape([])))
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
},
ex.label,
)
return tf.data.Dataset.from_generator(
gen,
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
(
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
tf.TensorShape([]),
),
)
return features
......@@ -156,21 +166,21 @@ class MrpcProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -186,8 +196,7 @@ class MrpcProcessor(DataProcessor):
text_a = line[3]
text_b = line[4]
label = line[0]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -196,21 +205,20 @@ class MnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['premise'].numpy().decode('utf-8'),
tensor_dict['hypothesis'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["premise"].numpy().decode("utf-8"),
tensor_dict["hypothesis"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
def get_labels(self):
"""See base class."""
......@@ -226,8 +234,7 @@ class MnliProcessor(DataProcessor):
text_a = line[8]
text_b = line[9]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -236,9 +243,7 @@ class MnliMismatchedProcessor(MnliProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_matched")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
class ColaProcessor(DataProcessor):
......@@ -246,20 +251,20 @@ class ColaProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -272,8 +277,7 @@ class ColaProcessor(DataProcessor):
guid = "%s-%s" % (set_type, i)
text_a = line[3]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
......@@ -282,20 +286,20 @@ class Sst2Processor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -310,8 +314,7 @@ class Sst2Processor(DataProcessor):
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
......@@ -320,20 +323,20 @@ class StsbProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -349,8 +352,7 @@ class StsbProcessor(DataProcessor):
text_a = line[7]
text_b = line[8]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -359,20 +361,20 @@ class QqpProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question1'].numpy().decode('utf-8'),
tensor_dict['question2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["question1"].numpy().decode("utf-8"),
tensor_dict["question2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -391,8 +393,7 @@ class QqpProcessor(DataProcessor):
label = line[5]
except IndexError:
continue
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -401,21 +402,20 @@ class QnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['question'].numpy().decode('utf-8'),
tensor_dict['sentence'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["question"].numpy().decode("utf-8"),
tensor_dict["sentence"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
"dev_matched")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_labels(self):
"""See base class."""
......@@ -431,8 +431,7 @@ class QnliProcessor(DataProcessor):
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -441,20 +440,20 @@ class RteProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -470,8 +469,7 @@ class RteProcessor(DataProcessor):
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -480,20 +478,20 @@ class WnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence1'].numpy().decode('utf-8'),
tensor_dict['sentence2'].numpy().decode('utf-8'),
str(tensor_dict['label'].numpy()))
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
......@@ -509,10 +507,10 @@ class WnliProcessor(DataProcessor):
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
glue_tasks_num_labels = {
"cola": 2,
"mnli": 3,
......
......@@ -82,8 +82,8 @@ def _is_whitespace(c):
return True
return False
def squad_convert_example_to_features(example, max_seq_length,
doc_stride, max_query_length, is_training):
def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
features = []
if is_training and not example.is_impossible:
# Get start and end position
......@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length,
end_position = example.end_position
# If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
......@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
spans = []
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
sequence_added_tokens = (
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
if "roberta" in str(type(tokenizer))
else tokenizer.max_len - tokenizer.max_len_single_sentence
)
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
span_doc_tokens = all_doc_tokens
......@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
return_overflowing_tokens=True,
pad_to_max_length=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
)
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
paragraph_len = min(
len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
)
if tokenizer.pad_token_id in encoded_dict['input_ids']:
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
else:
non_padded_ids = encoded_dict['input_ids']
non_padded_ids = encoded_dict["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
......@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = j if tokenizer.padding_side == "left" else spans[doc_span_index][
"truncated_query_with_special_tokens_length"] + j
index = (
j
if tokenizer.padding_side == "left"
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
)
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans:
# Identify the position of the CLS token
cls_index = span['input_ids'].index(tokenizer.cls_token_id)
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span['token_type_ids'])
p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1)
......@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
features.append(SquadFeatures(
span['input_ids'],
span['attention_mask'],
span['token_type_ids'],
cls_index,
p_mask.tolist(),
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
unique_id=0,
paragraph_len=span['paragraph_len'],
token_is_max_context=span["token_is_max_context"],
tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position,
end_position=end_position
))
features.append(
SquadFeatures(
span["input_ids"],
span["attention_mask"],
span["token_type_ids"],
cls_index,
p_mask.tolist(),
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
unique_id=0,
paragraph_len=span["paragraph_len"],
token_is_max_context=span["token_is_max_context"],
tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position,
end_position=end_position,
)
)
return features
def squad_convert_example_to_features_init(tokenizer_for_convert):
global tokenizer
tokenizer = tokenizer_for_convert
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_stride, max_query_length, is_training,
return_dataset=False, threads=1):
def squad_convert_examples_to_features(
examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
):
"""
Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
......@@ -279,17 +290,28 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
)
"""
# Defining helper methods
# Defining helper methods
features = []
threads = min(threads, cpu_count())
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features'))
annotate_ = partial(
squad_convert_example_to_features,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=is_training,
)
features = list(
tqdm(
p.imap(annotate_, examples, chunksize=32),
total=len(examples),
desc="convert squad examples to features",
)
)
new_features = []
unique_id = 1000000000
example_index = 0
for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
if not example_features:
continue
for example_feature in example_features:
......@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
example_index += 1
features = new_features
del new_features
if return_dataset == 'pt':
if return_dataset == "pt":
if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
......@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
}, {
},
{
"start_position": ex.start_position,
"end_position": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
}
},
)
return tf.data.Dataset.from_generator(
......
......@@ -24,6 +24,7 @@ from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__)
class InputExample(object):
"""
A single training/test example for simple sequence classification.
......@@ -37,6 +38,7 @@ class InputExample(object):
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
def __init__(self, guid, text_a, text_b=None, label=None):
self.guid = guid
self.text_a = text_a
......@@ -99,14 +101,15 @@ class DataProcessor(object):
lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
line = list(unicode(cell, "utf-8") for cell in line)
lines.append(line)
return lines
class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples
self.mode = mode
......@@ -117,22 +120,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
def __getitem__(self, idx):
if isinstance(idx, slice):
return SingleSentenceClassificationProcessor(labels=self.labels,
examples=self.examples[idx])
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
return self.examples[idx]
@classmethod
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
column_id=None, skip_first_row=False, **kwargs):
def create_from_csv(
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
):
processor = cls(**kwargs)
processor.add_examples_from_csv(file_name,
split_name=split_name,
column_label=column_label,
column_text=column_text,
column_id=column_id,
skip_first_row=skip_first_row,
overwrite_labels=True,
overwrite_examples=True)
processor.add_examples_from_csv(
file_name,
split_name=split_name,
column_label=column_label,
column_text=column_text,
column_id=column_id,
skip_first_row=skip_first_row,
overwrite_labels=True,
overwrite_examples=True,
)
return processor
@classmethod
......@@ -141,8 +146,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
processor.add_examples(texts_or_text_and_labels, labels=labels)
return processor
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
def add_examples_from_csv(
self,
file_name,
split_name="",
column_label=0,
column_text=1,
column_id=None,
skip_first_row=False,
overwrite_labels=False,
overwrite_examples=False,
):
lines = self._read_tsv(file_name)
if skip_first_row:
lines = lines[1:]
......@@ -158,10 +172,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid)
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
return self.add_examples(
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
)
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
overwrite_labels=False, overwrite_examples=False):
def add_examples(
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
):
assert labels is None or len(texts_or_text_and_labels) == len(labels)
assert ids is None or len(texts_or_text_and_labels) == len(ids)
if ids is None:
......@@ -192,13 +209,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return self.examples
def get_features(self,
tokenizer,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
return_tensors=None):
def get_features(
self,
tokenizer,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
return_tensors=None,
):
"""
Convert examples in a list of ``InputFeatures``
......@@ -231,9 +250,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
logger.info("Tokenizing example %d", ex_index)
input_ids = tokenizer.encode(
example.text_a,
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
)
all_input_ids.append(input_ids)
......@@ -256,8 +273,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
len(input_ids), batch_length
)
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
len(attention_mask), batch_length
)
if self.mode == "classification":
label = label_map[example.label]
......@@ -273,36 +294,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("label: %s (id = %d)" % (example.label, label))
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
label=label))
features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
if return_tensors is None:
return features
elif return_tensors == 'tf':
elif return_tensors == "tf":
if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf
def gen():
for ex in features:
yield ({'input_ids': ex.input_ids,
'attention_mask': ex.attention_mask},
ex.label)
dataset = tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None])},
tf.TensorShape([])))
yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
dataset = tf.data.Dataset.from_generator(
gen,
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
)
return dataset
elif return_tensors == 'pt':
elif return_tensors == "pt":
if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch
from torch.utils.data import TensorDataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
if self.mode == "classification":
......
......@@ -24,11 +24,12 @@ from .utils import DataProcessor, InputExample
logger = logging.getLogger(__name__)
class XnliProcessor(DataProcessor):
"""Processor for the XNLI dataset.
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
def __init__(self, language, train_language = None):
def __init__(self, language, train_language=None):
self.language = language
self.train_language = train_language
......@@ -40,13 +41,12 @@ class XnliProcessor(DataProcessor):
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % ('train', i)
guid = "%s-%s" % ("train", i)
text_a = line[0]
text_b = line[1]
label = "contradiction" if line[2] == "contradictory" else line[2]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
......@@ -59,19 +59,19 @@ class XnliProcessor(DataProcessor):
language = line[0]
if language != self.language:
continue
guid = "%s-%s" % ('test', i)
guid = "%s-%s" % ("test", i)
text_a = line[6]
text_b = line[7]
label = line[1]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
xnli_processors = {
"xnli": XnliProcessor,
}
......
......@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
......@@ -29,9 +29,10 @@ from filelock import FileLock
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
os.environ.setdefault('USE_TORCH', 'YES')
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
os.environ.setdefault("USE_TORCH", "YES")
if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
import torch
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
else:
......@@ -41,10 +42,11 @@ except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
os.environ.setdefault('USE_TF', 'YES')
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
os.environ.setdefault("USE_TF", "YES")
if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
......@@ -55,12 +57,13 @@ except (ImportError, AssertionError):
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv('TORCH_HOME', os.path.join(
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
default_cache_path = os.path.join(torch_cache_home, 'transformers')
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
default_cache_path = os.path.join(torch_cache_home, "transformers")
try:
from urllib.parse import urlparse
......@@ -69,19 +72,21 @@ except ImportError:
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
)
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
default_cache_path))
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
)
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt'
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
......@@ -95,38 +100,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available
if not six.PY2:
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = ''.join(docstr) + fn.__doc__
fn.__doc__ = "".join(docstr) + fn.__doc__
return fn
return docstring_decorator
def add_end_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + ''.join(docstr)
fn.__doc__ = fn.__doc__ + "".join(docstr)
return fn
return docstring_decorator
else:
# Not possible to update class docstrings on python2
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
return fn
return docstring_decorator
def add_end_docstrings(*docstr):
def docstring_decorator(fn):
return fn
return docstring_decorator
def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ('http', 'https', 's3')
return parsed.scheme in ("http", "https", "s3")
def hf_bucket_url(identifier, postfix=None, cdn=False):
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
......@@ -145,17 +160,17 @@ def url_to_filename(url, etag=None):
so that TF 2.0 can identify it as a HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode('utf-8')
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
filename += "." + etag_hash.hexdigest()
if url.endswith('.h5'):
filename += '.h5'
if url.endswith(".h5"):
filename += ".h5"
return filename
......@@ -174,19 +189,21 @@ def filename_to_url(filename, cache_dir=None):
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
def cached_path(
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
......@@ -207,13 +224,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir,
force_download=force_download, proxies=proxies,
resume_download=resume_download, user_agent=user_agent)
return get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif urlparse(url_or_filename).scheme == '':
elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
......@@ -273,31 +295,35 @@ def s3_get(url, temp_file, proxies=None):
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if isinstance(user_agent, dict):
ua += "; " + "; ".join(
"{}/{}".format(k, v) for k, v in user_agent.items()
)
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, six.string_types):
ua += "; "+ user_agent
headers = {
"user-agent": ua
}
ua += "; " + user_agent
headers = {"user-agent": ua}
if resume_size > 0:
headers['Range'] = 'bytes=%d-' % (resume_size,)
headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
content_length = response.headers.get('Content-Length')
content_length = response.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
desc="Downloading", disable=bool(logger.level<=logging.INFO))
progress = tqdm(
unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc="Downloading",
disable=bool(logger.level <= logging.INFO),
)
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
def get_from_cache(
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
......@@ -326,7 +352,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
etag = None
if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8')
etag = etag.decode("utf-8")
filename = url_to_filename(url, etag)
# get cache path to put the file
......@@ -337,22 +363,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if not os.path.exists(cache_path) and etag is None:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
if not file.endswith('.json') and not file.endswith('.lock')
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock'
lock_path = cache_path + ".lock"
with FileLock(lock_path):
if resume_download:
incomplete_path = cache_path + '.incomplete'
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
with open(incomplete_path, "a+b") as f:
yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
......@@ -366,7 +394,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
logger.info(
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
)
# GET file object
if url.startswith("s3://"):
......@@ -383,12 +413,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
output_string = unicode(output_string, "utf-8") # The beauty of python 2
meta_file.write(output_string)
return cache_path
......@@ -24,13 +24,14 @@ from tqdm import tqdm
ENDPOINT = "https://huggingface.co"
class S3Obj:
def __init__(
self,
filename, # type: str
LastModified, # type: str
ETag, # type: str
Size, # type: int
filename, # type: str
LastModified, # type: str
ETag, # type: str
Size, # type: int
**kwargs
):
self.filename = filename
......@@ -43,13 +44,13 @@ class PresignedUrl:
def __init__(
self,
write, # type: str
access, # type: str
type, # type: str
access, # type: str
type, # type: str
**kwargs
):
self.write = write
self.access = access
self.type = type # mime-type to send to S3.
self.type = type # mime-type to send to S3.
class HfApi:
......@@ -58,8 +59,8 @@ class HfApi:
def login(
self,
username, # type: str
password, # type: str
username, # type: str
password, # type: str
):
# type: (...) -> str
"""
......@@ -78,8 +79,7 @@ class HfApi:
return d["token"]
def whoami(
self,
token, # type: str
self, token, # type: str
):
# type: (...) -> str
"""
......@@ -106,11 +106,7 @@ class HfApi:
Call HF API to get a presigned url to upload `filename` to S3.
"""
path = "{}/api/presign".format(self.endpoint)
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"filename": filename},
)
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
r.raise_for_status()
d = r.json()
return PresignedUrl(**d)
......@@ -126,16 +122,14 @@ class HfApi:
urls = self.presign(token, filename=filename)
# streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
#
#
# Even though we presign with the correct content-type,
# the client still has to specify it when uploading the file.
with open(filepath, "rb") as f:
pf = TqdmProgressFileReader(f)
data = f if pf.total_size > 0 else ""
r = requests.put(urls.write, data=data, headers={
"content-type": urls.type,
})
r = requests.put(urls.write, data=data, headers={"content-type": urls.type,})
r.raise_for_status()
pf.close()
return urls.access
......@@ -152,7 +146,6 @@ class HfApi:
return [S3Obj(**x) for x in d]
class TqdmProgressFileReader:
"""
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
......@@ -161,12 +154,12 @@ class TqdmProgressFileReader:
see github.com/huggingface/transformers/pull/2078#discussion_r354739608
for implementation details.
"""
def __init__(
self,
f # type: io.BufferedReader
self, f # type: io.BufferedReader
):
self.f = f
self.total_size = os.fstat(f.fileno()).st_size # type: int
self.total_size = os.fstat(f.fileno()).st_size # type: int
self.pbar = tqdm(total=self.total_size, leave=False)
if six.PY3:
# does not work unless PY3
......@@ -182,7 +175,6 @@ class TqdmProgressFileReader:
self.pbar.close()
class HfFolder:
path_token = expanduser("~/.huggingface/token")
......@@ -201,7 +193,7 @@ class HfFolder:
if e.errno != os.errno.EEXIST:
raise e
pass
with open(cls.path_token, 'w+') as f:
with open(cls.path_token, "w+") as f:
f.write(token)
@classmethod
......@@ -210,7 +202,7 @@ class HfFolder:
Get token or None if not existent.
"""
try:
with open(cls.path_token, 'r') as f:
with open(cls.path_token, "r") as f:
return f.read()
except:
# this is too wide. When Py2 is dead use:
......
......@@ -14,8 +14,7 @@
# limitations under the License.
""" Configuration base class and utilities."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import json
......@@ -25,8 +24,15 @@ from io import open
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
cached_path, is_remote_url, hf_bucket_url
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
cached_path,
is_remote_url,
hf_bucket_url,
)
logger = logging.getLogger(__name__)
......@@ -48,17 +54,18 @@ class ModelCard(object):
Parameters:
"""
def __init__(self, **kwargs):
# Recomended attributes from https://arxiv.org/abs/1810.03993 (see papers)
self.model_details = kwargs.pop('model_details', {})
self.intended_use = kwargs.pop('intended_use', {})
self.factors = kwargs.pop('factors', {})
self.metrics = kwargs.pop('metrics', {})
self.evaluation_data = kwargs.pop('evaluation_data', {})
self.training_data = kwargs.pop('training_data', {})
self.quantitative_analyses = kwargs.pop('quantitative_analyses', {})
self.ethical_considerations = kwargs.pop('ethical_considerations', {})
self.caveats_and_recommendations = kwargs.pop('caveats_and_recommendations', {})
self.model_details = kwargs.pop("model_details", {})
self.intended_use = kwargs.pop("intended_use", {})
self.factors = kwargs.pop("factors", {})
self.metrics = kwargs.pop("metrics", {})
self.evaluation_data = kwargs.pop("evaluation_data", {})
self.training_data = kwargs.pop("training_data", {})
self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
self.ethical_considerations = kwargs.pop("ethical_considerations", {})
self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
# Open additional attributes
for key, value in kwargs.items():
......@@ -122,10 +129,10 @@ class ModelCard(object):
modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
"""
cache_dir = kwargs.pop('cache_dir', None)
proxies = kwargs.pop('proxies', None)
find_from_standard_name = kwargs.pop('find_from_standard_name', True)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
# For simplicity we use the same pretrained url than the configuration files
......@@ -145,36 +152,43 @@ class ModelCard(object):
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
proxies=proxies, resume_download=False)
resolved_model_card_file = cached_path(
model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
)
if resolved_model_card_file == model_card_file:
logger.info("loading model card file {}".format(model_card_file))
else:
logger.info("loading model card file {} from cache at {}".format(
model_card_file, resolved_model_card_file))
logger.info(
"loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
)
# Load model card
modelcard = cls.from_json_file(resolved_model_card_file)
except EnvironmentError:
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.warning("Couldn't reach server at '{}' to download model card file.".format(
model_card_file))
logger.warning("Couldn't reach server at '{}' to download model card file.".format(model_card_file))
else:
logger.warning("Model name '{}' was not found in model name list ({}). " \
"We assumed '{}' was a path or url to a model card file named {} or " \
"a directory containing such a file but couldn't find any such file at this path or url.".format(
logger.warning(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to a model card file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
model_card_file, MODEL_CARD_NAME))
", ".join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
model_card_file,
MODEL_CARD_NAME,
)
)
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
modelcard = cls()
except json.JSONDecodeError:
logger.warning("Couldn't reach server at '{}' to download model card file or "
"model card file is not a valid JSON file. "
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
logger.warning(
"Couldn't reach server at '{}' to download model card file or "
"model card file is not a valid JSON file. "
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file)
)
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
......@@ -203,7 +217,7 @@ class ModelCard(object):
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `ModelCard` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
dict_obj = json.loads(text)
return cls(**dict_obj)
......@@ -225,5 +239,5 @@ class ModelCard(object):
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
......@@ -30,14 +29,14 @@ logger = logging.getLogger(__name__)
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
}
......@@ -48,8 +47,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
......@@ -65,7 +66,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
for name, array in zip(names, arrays):
print(name)
for name, array in zip(names, arrays):
original_name = name
......@@ -75,10 +76,10 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
# Renaming and simplifying
name = name.replace("ffn_1", "ffn")
name = name.replace("bert/", "albert/")
name = name.replace("attention_1", "attention")
name = name.replace("attention_1", "attention")
name = name.replace("transform/", "")
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
name = name.replace("LayerNorm", "attention/LayerNorm")
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
name = name.replace("LayerNorm", "attention/LayerNorm")
name = name.replace("transformer/", "")
# The feed forward layer had an 'intermediate' step which has been abstracted away
......@@ -97,19 +98,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
name = name.replace("predictions/attention", "predictions")
# Naming was changed to be more explicit
name = name.replace("embeddings/attention", "embeddings")
name = name.replace("inner_group_", "albert_layers/")
name = name.replace("group_", "albert_layer_groups/")
name = name.replace("embeddings/attention", "embeddings")
name = name.replace("inner_group_", "albert_layers/")
name = name.replace("group_", "albert_layer_groups/")
# Classifier
if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
name = "classifier/" + name
# No ALBERT model currently handles the next sentence prediction task
# No ALBERT model currently handles the next sentence prediction task
if "seq_relationship" in name:
continue
name = name.split('/')
name = name.split("/")
# Ignore the gradients applied by the LAMB/ADAM optimizers.
if "adam_m" in name or "adam_v" in name or "global_step" in name:
......@@ -118,19 +119,19 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
l = re.split(r"_(\d+)", m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
if l[0] == "kernel" or l[0] == "gamma":
pointer = getattr(pointer, "weight")
elif l[0] == "output_bias" or l[0] == "beta":
pointer = getattr(pointer, "bias")
elif l[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif l[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, l[0])
......@@ -141,9 +142,9 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert pointer.shape == array.shape
......@@ -160,6 +161,7 @@ class AlbertEmbeddings(BertEmbeddings):
"""
Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(AlbertEmbeddings, self).__init__(config)
......@@ -175,7 +177,7 @@ class AlbertAttention(BertSelfAttention):
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.hidden_size = config.hidden_size
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
......@@ -237,10 +239,13 @@ class AlbertAttention(BertSelfAttention):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
reshaped_context_layer = context_layer.view(*new_context_layer_shape)
# Should find a better way to do this
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
w = (
self.dense.weight.t()
.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
.to(context_layer.dtype)
)
b = self.dense.bias.to(context_layer.dtype)
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
......@@ -252,11 +257,11 @@ class AlbertAttention(BertSelfAttention):
class AlbertLayer(nn.Module):
def __init__(self, config):
super(AlbertLayer, self).__init__()
self.config = config
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = AlbertAttention(config)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
self.activation = ACT2FN[config.hidden_act]
......@@ -273,7 +278,7 @@ class AlbertLayer(nn.Module):
class AlbertLayerGroup(nn.Module):
def __init__(self, config):
super(AlbertLayerGroup, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
......@@ -303,7 +308,7 @@ class AlbertLayerGroup(nn.Module):
class AlbertTransformer(nn.Module):
def __init__(self, config):
super(AlbertTransformer, self).__init__()
self.config = config
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
......@@ -327,8 +332,12 @@ class AlbertTransformer(nn.Module):
# Index of the layer inside the group
layer_idx = int(i - group_idx * layers_per_group)
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
layer_group_output = self.albert_layer_groups[group_idx](
hidden_states,
attention_mask,
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
)
hidden_states = layer_group_output[0]
if self.output_attentions:
......@@ -337,7 +346,6 @@ class AlbertTransformer(nn.Module):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
......@@ -346,11 +354,11 @@ class AlbertTransformer(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class AlbertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "albert"
......@@ -431,8 +439,12 @@ ALBERT_INPUTS_DOCSTRING = r"""
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
ALBERT_START_DOCSTRING,
ALBERT_INPUTS_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -500,8 +512,15 @@ class AlbertModel(AlbertPreTrainedModel):
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
inputs_embeds=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
......@@ -520,31 +539,37 @@ class AlbertModel(AlbertPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
embedding_output = self.embeddings(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
outputs = (sequence_output, pooled_output) + encoder_outputs[1:] # add hidden_states and attentions if they are here
outputs = (sequence_output, pooled_output) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs
class AlbertMLMHead(nn.Module):
def __init__(self, config):
super(AlbertMLMHead, self).__init__()
......@@ -566,7 +591,9 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -602,21 +629,28 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.predictions.decoder,
self.albert.embeddings.word_embeddings)
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
def get_output_embeddings(self):
return self.predictions.decoder
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
masked_lm_labels=None,
):
outputs = self.albert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds
inputs_embeds=inputs_embeds,
)
sequence_outputs = outputs[0]
......@@ -631,9 +665,12 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
return outputs
@add_start_docstrings("""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
ALBERT_START_DOCSTRING,
ALBERT_INPUTS_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -665,6 +702,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(AlbertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -675,8 +713,16 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
outputs = self.albert(
input_ids=input_ids,
......@@ -684,7 +730,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1]
......@@ -707,10 +753,12 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
ALBERT_START_DOCSTRING,
ALBERT_INPUTS_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -752,6 +800,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
"""
def __init__(self, config):
super(AlbertForQuestionAnswering, self).__init__(config)
self.num_labels = config.num_labels
......@@ -761,8 +810,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
inputs_embeds=None, start_positions=None, end_positions=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
):
outputs = self.albert(
input_ids=input_ids,
......@@ -770,7 +828,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
......
......@@ -18,31 +18,87 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import logging
from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .configuration_auto import (
AlbertConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
TransfoXLConfig,
XLMConfig,
XLNetConfig,
XLMRobertaConfig,
)
from .modeling_bert import (
BertModel,
BertForMaskedLM,
BertForSequenceClassification,
BertForQuestionAnswering,
BertForTokenClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \
XLNetForTokenClassification, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering, \
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, \
RobertaForTokenClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, \
DistilBertForSequenceClassification, DistilBertForTokenClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \
CamembertForMultipleChoice, CamembertForTokenClassification, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlnet import (
XLNetModel,
XLNetLMHeadModel,
XLNetForSequenceClassification,
XLNetForQuestionAnswering,
XLNetForTokenClassification,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_xlm import (
XLMModel,
XLMWithLMHeadModel,
XLMForSequenceClassification,
XLMForQuestionAnswering,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_roberta import (
RobertaModel,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaForTokenClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_distilbert import (
DistilBertModel,
DistilBertForQuestionAnswering,
DistilBertForMaskedLM,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_camembert import (
CamembertModel,
CamembertForMaskedLM,
CamembertForSequenceClassification,
CamembertForMultipleChoice,
CamembertForTokenClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_albert import (
AlbertModel,
AlbertForMaskedLM,
AlbertForSequenceClassification,
AlbertForQuestionAnswering,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlm_roberta import (
XLMRobertaModel,
XLMRobertaForMaskedLM,
XLMRobertaForSequenceClassification,
XLMRobertaForMultipleChoice,
XLMRobertaForTokenClassification,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_utils import PreTrainedModel, SequenceSummary
......@@ -51,7 +107,8 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
......@@ -66,8 +123,9 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items())
]
for key, value, in pretrained_map.items()
)
class AutoModel(object):
......@@ -98,10 +156,13 @@ class AutoModel(object):
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModel is designed to be instantiated "
raise EnvironmentError(
"AutoModel is designed to be instantiated "
"using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModel.from_config(config)` methods.")
"`AutoModel.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
......@@ -232,35 +293,39 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
if "t5" in pretrained_model_name_or_path:
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
elif "albert" in pretrained_model_name_or_path:
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
elif "camembert" in pretrained_model_name_or_path:
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
elif "roberta" in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
elif "ctrl" in pretrained_model_name_or_path:
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(
pretrained_model_name_or_path
)
)
class AutoModelWithLMHead(object):
......@@ -291,10 +356,13 @@ class AutoModelWithLMHead(object):
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
raise EnvironmentError(
"AutoModelWithLMHead is designed to be instantiated "
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelWithLMHead.from_config(config)` methods.")
"`AutoModelWithLMHead.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
......@@ -423,35 +491,39 @@ class AutoModelWithLMHead(object):
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
if "t5" in pretrained_model_name_or_path:
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
elif "albert" in pretrained_model_name_or_path:
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
elif "camembert" in pretrained_model_name_or_path:
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
elif "roberta" in pretrained_model_name_or_path:
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
elif "gpt2" in pretrained_model_name_or_path:
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
elif "ctrl" in pretrained_model_name_or_path:
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
pretrained_model_name_or_path
)
)
class AutoModelForSequenceClassification(object):
......@@ -477,10 +549,13 @@ class AutoModelForSequenceClassification(object):
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelForSequenceClassification is designed to be instantiated "
raise EnvironmentError(
"AutoModelForSequenceClassification is designed to be instantiated "
"using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForSequenceClassification.from_config(config)` methods.")
"`AutoModelForSequenceClassification.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
......@@ -597,25 +672,39 @@ class AutoModelForSequenceClassification(object):
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
return AlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
return XLMRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
if "distilbert" in pretrained_model_name_or_path:
return DistilBertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "albert" in pretrained_model_name_or_path:
return AlbertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "camembert" in pretrained_model_name_or_path:
return CamembertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
return RobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "bert" in pretrained_model_name_or_path:
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
pretrained_model_name_or_path
)
)
class AutoModelForQuestionAnswering(object):
......@@ -638,10 +727,13 @@ class AutoModelForQuestionAnswering(object):
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelForQuestionAnswering is designed to be instantiated "
raise EnvironmentError(
"AutoModelForQuestionAnswering is designed to be instantiated "
"using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForQuestionAnswering.from_config(config)` methods.")
"`AutoModelForQuestionAnswering.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
......@@ -745,26 +837,30 @@ class AutoModelForQuestionAnswering(object):
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'distilbert' in pretrained_model_name_or_path:
if "distilbert" in pretrained_model_name_or_path:
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
elif "albert" in pretrained_model_name_or_path:
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path))
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
)
class AutoModelForTokenClassification:
def __init__(self):
raise EnvironmentError("AutoModelForTokenClassification is designed to be instantiated "
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForTokenClassification.from_config(config)` methods.")
raise EnvironmentError(
"AutoModelForTokenClassification is designed to be instantiated "
"using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForTokenClassification.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
......@@ -797,7 +893,7 @@ class AutoModelForTokenClassification:
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForTokenClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
......@@ -870,18 +966,28 @@ class AutoModelForTokenClassification:
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'camembert' in pretrained_model_name_or_path:
return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
if "camembert" in pretrained_model_name_or_path:
return CamembertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
pretrained_model_name_or_path
)
)
......@@ -33,27 +33,27 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
}
......@@ -65,8 +65,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
......@@ -81,7 +83,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split('/')
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
......@@ -89,18 +91,18 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
l = re.split(r"_(\d+)", m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
elif l[0] == 'squad':
pointer = getattr(pointer, 'classifier')
if l[0] == "kernel" or l[0] == "gamma":
pointer = getattr(pointer, "weight")
elif l[0] == "output_bias" or l[0] == "beta":
pointer = getattr(pointer, "bias")
elif l[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif l[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, l[0])
......@@ -110,9 +112,9 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert pointer.shape == array.shape
......@@ -157,6 +159,7 @@ BertLayerNorm = torch.nn.LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
......@@ -199,7 +202,8 @@ class BertSelfAttention(nn.Module):
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
......@@ -217,7 +221,14 @@ class BertSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
......@@ -307,8 +318,17 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
self_outputs = self.self(
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
......@@ -353,13 +373,22 @@ class BertLayer(nn.Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
cross_attention_outputs = self.crossattention(
attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
......@@ -376,14 +405,23 @@ class BertEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
)
hidden_states = layer_outputs[0]
if self.output_attentions:
......@@ -440,9 +478,7 @@ class BertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
......@@ -488,6 +524,7 @@ class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = BertConfig
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
......@@ -581,8 +618,12 @@ BERT_INPUTS_DOCSTRING = r"""
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -612,6 +653,7 @@ class BertModel(BertPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.config = config
......@@ -636,8 +678,17 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
""" Forward pass on the Model.
The model can behave as an encoder (with only self-attention) as well
......@@ -681,12 +732,18 @@ class BertModel(BertPreTrainedModel):
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
causal_mask = causal_mask.to(
torch.long
) # not converting to long will cause errors with pytorch version < 1.3
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
......@@ -709,10 +766,15 @@ class BertModel(BertPreTrainedModel):
elif encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
else:
raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
encoder_attention_mask.shape))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
raise ValueError(
"Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
encoder_hidden_shape, encoder_attention_mask.shape
)
)
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
else:
encoder_extended_attention_mask = None
......@@ -727,28 +789,40 @@ class BertModel(BertPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
@add_start_docstrings(
"""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -786,6 +860,7 @@ class BertForPreTraining(BertPreTrainedModel):
prediction_scores, seq_relationship_scores = outputs[:2]
"""
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
......@@ -797,20 +872,33 @@ class BertForPreTraining(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, next_sentence_label=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
masked_lm_labels=None,
next_sentence_label=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
outputs = (prediction_scores, seq_relationship_score,) + outputs[
2:
] # add hidden states and attention if they are here
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
......@@ -822,9 +910,9 @@ class BertForPreTraining(BertPreTrainedModel):
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING
)
class BertForMaskedLM(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -862,6 +950,7 @@ class BertForMaskedLM(BertPreTrainedModel):
loss, prediction_scores = outputs[:2]
"""
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
......@@ -873,17 +962,30 @@ class BertForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
masked_lm_labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
lm_labels=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
......@@ -912,9 +1014,11 @@ class BertForMaskedLM(BertPreTrainedModel):
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
r"""
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -945,6 +1049,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_scores = outputs[0]
"""
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
......@@ -953,15 +1058,25 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
next_sentence_label=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
next_sentence_label=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1]
......@@ -976,10 +1091,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1011,6 +1128,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1021,15 +1139,25 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1]
......@@ -1051,10 +1179,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
@add_start_docstrings(
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1087,6 +1217,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
loss, classification_scores = outputs[:2]
"""
def __init__(self, config):
super(BertForMultipleChoice, self).__init__(config)
......@@ -1096,8 +1227,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
num_choices = input_ids.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1))
......@@ -1105,12 +1244,14 @@ class BertForMultipleChoice(BertPreTrainedModel):
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1]
......@@ -1128,10 +1269,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
@add_start_docstrings(
"""Bert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -1161,6 +1304,7 @@ class BertForTokenClassification(BertPreTrainedModel):
loss, scores = outputs[:2]
"""
def __init__(self, config):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1171,15 +1315,25 @@ class BertForTokenClassification(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
......@@ -1202,10 +1356,12 @@ class BertForTokenClassification(BertPreTrainedModel):
return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1247,6 +1403,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1256,15 +1413,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment