Unverified commit 54abc67a authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
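Every file changed below follows the same set of conventions: string literals use double quotes, the "from __future__" imports are collapsed onto a single line, imports are grouped and sorted, keyword arguments lose the spaces around "=", and long call sites are wrapped one argument per line with a trailing comma, in the style produced by the black formatter (an inference from the diff, not something the commit message states). As an illustration only, using a made-up "--example_path" argument rather than a line taken from any of the files:

import argparse

# Old style: single quotes, spaces around "=", hanging indent aligned to the opening parenthesis.
old_parser = argparse.ArgumentParser()
old_parser.add_argument('--example_path',
                        default = None,
                        type = str,
                        required = True,
                        help = 'Path to an example file.')

# New style: double quotes, no spaces around "=", black-style wrapping and line length.
new_parser = argparse.ArgumentParser()
new_parser.add_argument(
    "--example_path", default=None, type=str, required=True, help="Path to an example file."
)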
@@ -14,18 +14,19 @@
# limitations under the License.
"""Convert BERT checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
import logging

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert


logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
@@ -42,24 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
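The body of convert_tf_checkpoint_to_pytorch sits in the lines collapsed between the two hunks above. For orientation, a minimal sketch of that standard conversion pattern, using only the helpers imported at the top of the file (the print statement and its exact wording are assumptions, not lines from this diff):

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Build a randomly initialised PyTorch model from the JSON configuration.
    config = BertConfig.from_json_file(bert_config_file)
    model = BertForPreTraining(config)

    # Copy the TensorFlow checkpoint weights into the PyTorch model.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Serialise the converted weights.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)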
@@ -15,15 +15,17 @@
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""

import argparse
import os

import numpy as np
import tensorflow as tf
import torch

from transformers import BertModel


def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

    """
    :param model:BertModel Pytorch model instance to be converted
@@ -41,22 +43,17 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
        N BertForQuestionAnswering
    """
    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    if not os.path.isdir(ckpt_dir):
@@ -64,12 +61,12 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return "bert/{}".format(name)
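    # For example, the state-dict key "encoder.layer.0.attention.self.query.weight"
    # becomes the TF variable name "bert/encoder/layer_0/attention/self/query/kernel".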
    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
@@ -94,37 +91,22 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s

def main(raw_args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
    parser.add_argument(
        "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
    )
    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
    args = parser.parse_args(raw_args)

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        state_dict=torch.load(args.pytorch_model_path),
        cache_dir=args.cache_dir,
    )

    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)


if __name__ == "__main__":
    main()
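# Hypothetical invocation of this script's entry point; the paths below are
# placeholders, not values taken from this diff:
#
#   main([
#       "--model_name", "bert-base-uncased",
#       "--pytorch_model_path", "/path/to/pytorch_model.bin",
#       "--tf_cache_dir", "/path/to/tf_checkpoints",
#   ])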
@@ -17,16 +17,14 @@

from __future__ import absolute_import, division, print_function

import argparse
import logging
from io import open

import torch

from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2


logging.basicConfig(level=logging.INFO)

@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
@@ -53,23 +51,19 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--gpt2_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
        "This specifies the model architecture.",
    )
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
@@ -17,16 +17,14 @@

from __future__ import absolute_import, division, print_function

import argparse
import logging
from io import open

import torch

from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt


logging.basicConfig(level=logging.INFO)

@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
    load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
@@ -53,23 +51,25 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--openai_checkpoint_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the TensorFlow checkpoint path.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--openai_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
        "This specifies the model architecture.",
    )
    args = parser.parse_args()
    convert_openai_checkpoint_to_pytorch(
        args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
    )
@@ -14,92 +14,276 @@
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow """

from __future__ import absolute_import, division, print_function

import argparse
import logging
import os

from transformers import (
    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    AlbertConfig,
    BertConfig,
    CTRLConfig,
    DistilBertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TFAlbertForMaskedLM,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    TFCTRLLMHeadModel,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFGPT2LMHeadModel,
    TFOpenAIGPTLMHeadModel,
    TFRobertaForMaskedLM,
    TFRobertaForSequenceClassification,
    TFT5WithLMHeadModel,
    TFTransfoXLLMHeadModel,
    TFXLMWithLMHeadModel,
    TFXLNetLMHeadModel,
    TransfoXLConfig,
    XLMConfig,
    XLNetConfig,
    cached_path,
    is_torch_available,
    load_pytorch_checkpoint_in_tf2_model,
)


if is_torch_available():
    import torch
    import numpy as np

    from transformers import (
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DistilBertForMaskedLM,
        DistilBertForQuestionAnswering,
        DistilBertForSequenceClassification,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
    )
else:
    (
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        DistilBertForMaskedLM,
        DistilBertForSequenceClassification,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
    ) = (
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    )

logging.basicConfig(level=logging.INFO)

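# Each MODEL_CLASSES entry maps a model type or shortcut name to a 5-tuple:
# (config class, TensorFlow model class, PyTorch model class, pretrained weights map, pretrained config map).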
MODEL_CLASSES = {
    "bert": (
        BertConfig,
        TFBertForPreTraining,
        BertForPreTraining,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "gpt2": (
        GPT2Config,
        TFGPT2LMHeadModel,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlnet": (
        XLNetConfig,
        TFXLNetLMHeadModel,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm": (
        XLMConfig,
        TFXLMWithLMHeadModel,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "transfo-xl": (
        TransfoXLConfig,
        TFTransfoXLLMHeadModel,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "openai-gpt": (
        OpenAIGPTConfig,
        TFOpenAIGPTLMHeadModel,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta": (
        RobertaConfig,
        TFRobertaForMaskedLM,
        RobertaForMaskedLM,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert": (
        DistilBertConfig,
        TFDistilBertForMaskedLM,
        DistilBertForMaskedLM,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert-base-uncased-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "ctrl": (
        CTRLConfig,
        TFCTRLLMHeadModel,
        CTRLLMHeadModel,
        CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "albert": (
        AlbertConfig,
        TFAlbertForMaskedLM,
        AlbertForMaskedLM,
        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "t5": (
        T5Config,
        TFT5WithLMHeadModel,
        T5WithLMHeadModel,
        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
}

def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    if model_type not in MODEL_CLASSES:
        raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

@@ -116,17 +300,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
    # Load weights from tf checkpoint
    if pytorch_checkpoint_path in aws_model_maps:
        pytorch_checkpoint_path = cached_path(
            aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
        )

    # Load PyTorch checkpoint in tf2 model:
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

@@ -139,11 +325,19 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
    # Save pytorch-model
    print("Save TensorFlow model to {}".format(tf_dump_path))
    tf_model.save_weights(tf_dump_path, save_format="h5")

def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"

    if args_model_type is None:
@@ -156,7 +350,9 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(
                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
            )

        config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

@@ -166,9 +362,10 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
                    continue
@@ -176,7 +373,11 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
            elif only_convert_finetuned_models:
                print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
                continue
            print(
                " Converting checkpoint {}/{}: {} - model_type {}".format(
                    i, len(aws_config_map), model_shortcut_name, model_type
                )
            )
            print("-" * 100)

            if config_shortcut_name in aws_config_map:
@@ -190,13 +391,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
                model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)

            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)

@@ -204,40 +407,48 @@
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
            list(MODEL_CLASSES.keys())
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
        "If not given, will download and convert all the checkpoints from AWS.",
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture. If not given and "
        "--pytorch_checkpoint_path is not given or is a shortcut name"
        "use the configuration associated to the shortcut name on the AWS",
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # if args.pytorch_checkpoint_path is not None:
@@ -248,11 +459,15 @@ if __name__ == "__main__":
    #     compare_with_pt_model=args.compare_with_pt_model,
    #     use_cached_models=args.use_cached_models)
    # else:
    convert_all_pt_checkpoints_to_tf(
        args.model_type.lower() if args.model_type is not None else None,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
@@ -18,32 +18,33 @@ from __future__ import absolute_import, division, print_function

import argparse
import logging
import pathlib

import fairseq
import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers.modeling_bert import (
    BertConfig,
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"

def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
@@ -61,7 +62,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
        intermediate_size=roberta.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = roberta.args.num_classes
@@ -74,7 +75,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
    # Embeddings
    model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
    model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.roberta.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c RoBERTa doesn't use them.
    model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
    model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
@@ -83,13 +86,13 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
        layer: BertLayer = model.roberta.encoder.layer[i]
        roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert (
            roberta_layer.self_attn.k_proj.weight.data.shape
            == roberta_layer.self_attn.q_proj.weight.data.shape
            == roberta_layer.self_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )

        self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
@@ -99,40 +102,34 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
        self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
        self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
        intermediate.dense.weight = roberta_layer.fc1.weight
        intermediate.dense.bias = roberta_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
        bert_output.dense.weight = roberta_layer.fc2.weight
        bert_output.dense.bias = roberta_layer.fc2.bias
        bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
@@ -143,21 +140,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
        model.lm_head.bias = roberta.model.decoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
    else:
        their_output = roberta.model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

@@ -168,24 +162,17 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_roberta_checkpoint_to_pytorch(
        args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )
@@ -14,18 +14,19 @@
# limitations under the License.
"""Convert T5 checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
import logging

import torch

from transformers import T5Config, T5Model, load_tf_weights_in_t5


logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
@@ -42,24 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained T5 model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
@@ -17,6 +17,7 @@

from __future__ import absolute_import, division, print_function

import argparse
import logging
import os
import sys
from io import open
@@ -24,44 +25,48 @@ from io import open

import torch

import transformers.tokenization_transfo_xl as data_utils
from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    TransfoXLConfig,
    TransfoXLLMHeadModel,
    load_tf_weights_in_transfo_xl,
)
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES


if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


logging.basicConfig(level=logging.INFO)

# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules["data_utils"] = data_utils
sys.modules["vocabulary"] = data_utils

def convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
):
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")

        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop("vocab", None)
        pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

@@ -92,26 +97,36 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the folder to store the PyTorch model or dataset/vocab.",
    )
    parser.add_argument(
        "--tf_checkpoint_path",
        default="",
        type=str,
        help="An optional path to a TensorFlow checkpoint path to be converted.",
    )
    parser.add_argument(
        "--transfo_xl_config_file",
        default="",
        type=str,
        help="An optional config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--transfo_xl_dataset_file",
        default="",
        type=str,
        help="An optional dataset file to be converted in a vocabulary.",
    )
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(
        args.tf_checkpoint_path,
        args.transfo_xl_config_file,
        args.pytorch_dump_folder_path,
        args.transfo_xl_dataset_file,
    )
...@@ -18,41 +18,43 @@ from __future__ import absolute_import, division, print_function ...@@ -18,41 +18,43 @@ from __future__ import absolute_import, division, print_function
import argparse import argparse
import json import json
import logging
from io import open from io import open
import torch
import numpy import numpy
import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.tokenization_xlm import VOCAB_FILES_NAMES from transformers.tokenization_xlm import VOCAB_FILES_NAMES
import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
# Load checkpoint # Load checkpoint
chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")
state_dict = chkpt['model'] state_dict = chkpt["model"]
# We have the base model one level deeper than the original XLM repository # We have the base model one level deeper than the original XLM repository
two_levels_state_dict = {} two_levels_state_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'pred_layer' in k: if "pred_layer" in k:
two_levels_state_dict[k] = v two_levels_state_dict[k] = v
else: else:
two_levels_state_dict['transformer.' + k] = v two_levels_state_dict["transformer." + k] = v
config = chkpt['params'] config = chkpt["params"]
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
vocab = chkpt['dico_word2id'] vocab = chkpt["dico_word2id"]
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(two_levels_state_dict, pytorch_weights_dump_path) torch.save(two_levels_state_dict, pytorch_weights_dump_path)
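A small self-contained check (illustrative tensors, not from the commit) of the re-keying loop above: every key except the prediction layer gains a "transformer." prefix, matching a head model that holds the base XLM model as an attribute one level deeper.

import torch

original = {"embeddings.weight": torch.zeros(2), "pred_layer.proj.weight": torch.zeros(2)}
remapped = {(k if "pred_layer" in k else "transformer." + k): v for k, v in original.items()}
assert set(remapped) == {"transformer.embeddings.weight", "pred_layer.proj.weight"}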
...@@ -68,16 +70,12 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p ...@@ -68,16 +70,12 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--xlm_checkpoint_path", parser.add_argument(
default = None, "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
type = str, )
required = True, parser.add_argument(
help = "Path the official PyTorch dump.") "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
parser.add_argument("--pytorch_dump_folder_path", )
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args() args = parser.parse_args()
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
...@@ -14,19 +14,24 @@ ...@@ -14,19 +14,24 @@
# limitations under the License. # limitations under the License.
"""Convert BERT checkpoint.""" """Convert BERT checkpoint."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import os
import argparse import argparse
import logging
import os
import torch import torch
from transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import (
XLNetConfig, CONFIG_NAME,
XLNetLMHeadModel, XLNetForQuestionAnswering, WEIGHTS_NAME,
XLNetForSequenceClassification, XLNetConfig,
load_tf_weights_in_xlnet) XLNetForQuestionAnswering,
XLNetForSequenceClassification,
XLNetLMHeadModel,
load_tf_weights_in_xlnet,
)
GLUE_TASKS_NUM_LABELS = { GLUE_TASKS_NUM_LABELS = {
"cola": 2, "cola": 2,
...@@ -40,10 +45,13 @@ GLUE_TASKS_NUM_LABELS = { ...@@ -40,10 +45,13 @@ GLUE_TASKS_NUM_LABELS = {
"wnli": 2, "wnli": 2,
} }
import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
def convert_xlnet_checkpoint_to_pytorch(
tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None
):
# Initialise PyTorch model # Initialise PyTorch model
config = XLNetConfig.from_json_file(bert_config_file) config = XLNetConfig.from_json_file(bert_config_file)
...@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py ...@@ -53,7 +61,7 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
config.finetuning_task = finetuning_task config.finetuning_task = finetuning_task
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
model = XLNetForSequenceClassification(config) model = XLNetForSequenceClassification(config)
elif 'squad' in finetuning_task: elif "squad" in finetuning_task:
config.finetuning_task = finetuning_task config.finetuning_task = finetuning_task
model = XLNetForQuestionAnswering(config) model = XLNetForQuestionAnswering(config)
else: else:
...@@ -74,31 +82,34 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py ...@@ -74,31 +82,34 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--tf_checkpoint_path", parser.add_argument(
default = None, "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
type = str, )
required = True, parser.add_argument(
help = "Path to the TensorFlow checkpoint path.") "--xlnet_config_file",
parser.add_argument("--xlnet_config_file", default=None,
default = None, type=str,
type = str, required=True,
required = True, help="The config json file corresponding to the pre-trained XLNet model. \n"
help = "The config json file corresponding to the pre-trained XLNet model. \n" "This specifies the model architecture.",
"This specifies the model architecture.") )
parser.add_argument("--pytorch_dump_folder_path", parser.add_argument(
default = None, "--pytorch_dump_folder_path",
type = str, default=None,
required = True, type=str,
help = "Path to the folder to store the PyTorch model or dataset/vocab.") required=True,
parser.add_argument("--finetuning_task", help="Path to the folder to store the PyTorch model or dataset/vocab.",
default = None, )
type = str, parser.add_argument(
help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") "--finetuning_task",
default=None,
type=str,
help="Name of a task on which the XLNet TensorFloaw model was fine-tuned",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, convert_xlnet_checkpoint_to_pytorch(
args.xlnet_config_file, args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
args.pytorch_dump_folder_path, )
args.finetuning_task)
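A hedged sketch (not part of the diff) that condenses the head-selection logic above: the fine-tuning task name decides which PyTorch head wraps the converted weights. GLUE_TASKS_NUM_LABELS refers to the dict defined earlier in this script; the helper name is hypothetical.

from transformers import (
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetForSequenceClassification,
    XLNetLMHeadModel,
)

def build_xlnet_head(config: XLNetConfig, finetuning_task=None):
    # GLUE_TASKS_NUM_LABELS is the task -> num_labels mapping defined above in this file
    if finetuning_task in GLUE_TASKS_NUM_LABELS:        # GLUE classification head
        config.finetuning_task = finetuning_task
        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
        return XLNetForSequenceClassification(config)
    if finetuning_task and "squad" in finetuning_task:  # span-prediction head
        config.finetuning_task = finetuning_task
        return XLNetForQuestionAnswering(config)
    return XLNetLMHeadModel(config)                     # default: language-model head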
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor # flake8: noqa
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features # There's no way to ignore "F401 '...' imported but unused" warnings in this
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor # module, but to preserve other warnings. So, don't check this module at all.
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
from .metrics import is_sklearn_available from .metrics import is_sklearn_available
from .processors import (
DataProcessor,
InputExample,
InputFeatures,
SingleSentenceClassificationProcessor,
SquadExample,
SquadFeatures,
SquadV1Processor,
SquadV2Processor,
glue_convert_examples_to_features,
glue_output_modes,
glue_processors,
glue_tasks_num_labels,
squad_convert_examples_to_features,
xnli_output_modes,
xnli_processors,
xnli_tasks_num_labels,
)
if is_sklearn_available(): if is_sklearn_available():
from .metrics import glue_compute_metrics, xnli_compute_metrics from .metrics import glue_compute_metrics, xnli_compute_metrics
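A hedged usage sketch (assumes scikit-learn is installed): the guarded import above means the metric helpers are only exposed when sklearn is available, so callers should check first.

import numpy as np

from transformers.data import is_sklearn_available

if is_sklearn_available():
    from transformers.data import glue_compute_metrics

    preds = np.array([1, 0, 1])
    labels = np.array([1, 1, 1])
    print(glue_compute_metrics("mrpc", preds, labels))  # e.g. {'acc': ..., 'f1': ..., 'acc_and_f1': ...}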
...@@ -14,29 +14,30 @@ ...@@ -14,29 +14,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import csv
import sys
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score from sklearn.metrics import matthews_corrcoef, f1_score
_has_sklearn = True _has_sklearn = True
except (AttributeError, ImportError) as e: except (AttributeError, ImportError) as e:
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html") logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
_has_sklearn = False _has_sklearn = False
def is_sklearn_available(): def is_sklearn_available():
return _has_sklearn return _has_sklearn
if _has_sklearn: if _has_sklearn:
def simple_accuracy(preds, labels): def simple_accuracy(preds, labels):
return (preds == labels).mean() return (preds == labels).mean()
def acc_and_f1(preds, labels): def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels) acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds) f1 = f1_score(y_true=labels, y_pred=preds)
...@@ -46,7 +47,6 @@ if _has_sklearn: ...@@ -46,7 +47,6 @@ if _has_sklearn:
"acc_and_f1": (acc + f1) / 2, "acc_and_f1": (acc + f1) / 2,
} }
def pearson_and_spearman(preds, labels): def pearson_and_spearman(preds, labels):
pearson_corr = pearsonr(preds, labels)[0] pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0] spearman_corr = spearmanr(preds, labels)[0]
...@@ -56,7 +56,6 @@ if _has_sklearn: ...@@ -56,7 +56,6 @@ if _has_sklearn:
"corr": (pearson_corr + spearman_corr) / 2, "corr": (pearson_corr + spearman_corr) / 2,
} }
def glue_compute_metrics(task_name, preds, labels): def glue_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels) assert len(preds) == len(labels)
if task_name == "cola": if task_name == "cola":
...@@ -82,7 +81,6 @@ if _has_sklearn: ...@@ -82,7 +81,6 @@ if _has_sklearn:
else: else:
raise KeyError(task_name) raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels): def xnli_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels) assert len(preds) == len(labels)
if task_name == "xnli": if task_name == "xnli":
......
...@@ -8,35 +8,37 @@ that a question is unanswerable. ...@@ -8,35 +8,37 @@ that a question is unanswerable.
""" """
import collections
import json import json
import logging import logging
import math import math
import collections
from io import open
from tqdm import tqdm
import string
import re import re
import string
from io import open
from transformers.tokenization_bert import BasicTokenizer
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def normalize_answer(s): def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace.""" """Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text): def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, ' ', text) return re.sub(regex, " ", text)
def white_space_fix(text): def white_space_fix(text):
return ' '.join(text.split()) return " ".join(text.split())
def remove_punc(text): def remove_punc(text):
exclude = set(string.punctuation) exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude) return "".join(ch for ch in text if ch not in exclude)
def lower(text): def lower(text):
return text.lower() return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s)))) return white_space_fix(remove_articles(remove_punc(lower(s))))
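A quick illustrative check of the normalization above (example strings are made up): punctuation and articles are dropped, whitespace is collapsed, and the text is lower-cased.

print(normalize_answer("The  Eiffel Tower!"))   # -> "eiffel tower"
print(normalize_answer("an apple, a day"))      # -> "apple day"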
...@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds): ...@@ -75,14 +77,14 @@ def get_raw_scores(examples, preds):
for example in examples: for example in examples:
qas_id = example.qas_id qas_id = example.qas_id
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])] gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
if not gold_answers: if not gold_answers:
# For unanswerable questions, only correct answer is empty string # For unanswerable questions, only correct answer is empty string
gold_answers = [''] gold_answers = [""]
if qas_id not in preds: if qas_id not in preds:
print('Missing prediction for %s' % qas_id) print("Missing prediction for %s" % qas_id)
continue continue
prediction = preds[qas_id] prediction = preds[qas_id]
...@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): ...@@ -106,23 +108,27 @@ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
def make_eval_dict(exact_scores, f1_scores, qid_list=None): def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list: if not qid_list:
total = len(exact_scores) total = len(exact_scores)
return collections.OrderedDict([ return collections.OrderedDict(
('exact', 100.0 * sum(exact_scores.values()) / total), [
('f1', 100.0 * sum(f1_scores.values()) / total), ("exact", 100.0 * sum(exact_scores.values()) / total),
('total', total), ("f1", 100.0 * sum(f1_scores.values()) / total),
]) ("total", total),
]
)
else: else:
total = len(qid_list) total = len(qid_list)
return collections.OrderedDict([ return collections.OrderedDict(
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), [
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('total', total), ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
]) ("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix): def merge_eval(main_eval, new_eval, prefix):
for k in new_eval: for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k] main_eval["%s_%s" % (prefix, k)] = new_eval[k]
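A tiny illustrative run (hypothetical per-question scores) of the aggregation helpers above: make_eval_dict averages the per-question scores, and merge_eval copies a sub-split's numbers into the main dict under a prefix.

exact_scores = {"q1": 1.0, "q2": 0.0}
f1_scores = {"q1": 1.0, "q2": 0.5}

overall = make_eval_dict(exact_scores, f1_scores)                   # exact=50.0, f1=75.0, total=2
has_ans = make_eval_dict(exact_scores, f1_scores, qid_list=["q1"])
merge_eval(overall, has_ans, "HasAns")                              # adds HasAns_exact, HasAns_f1, HasAns_total
print(overall)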
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
...@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): ...@@ -160,16 +166,14 @@ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2( best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
preds, exact_raw, na_probs, qid_to_has_ans) best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2( main_eval["best_exact"] = best_exact
preds, f1_raw, na_probs, qid_to_has_ans) main_eval["best_exact_thresh"] = exact_thresh
main_eval['best_exact'] = best_exact main_eval["best_f1"] = best_f1
main_eval['best_exact_thresh'] = exact_thresh main_eval["best_f1_thresh"] = f1_thresh
main_eval['best_f1'] = best_f1 main_eval["has_ans_exact"] = has_ans_exact
main_eval['best_f1_thresh'] = f1_thresh main_eval["has_ans_f1"] = has_ans_f1
main_eval['has_ans_exact'] = has_ans_exact
main_eval['has_ans_f1'] = has_ans_f1
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
...@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h ...@@ -199,10 +203,10 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact main_eval["best_exact"] = best_exact
main_eval['best_exact_thresh'] = exact_thresh main_eval["best_exact_thresh"] = exact_thresh
main_eval['best_f1'] = best_f1 main_eval["best_f1"] = best_f1
main_eval['best_f1_thresh'] = f1_thresh main_eval["best_f1_thresh"] = f1_thresh
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0): def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
...@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_ ...@@ -215,18 +219,20 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_
exact, f1 = get_raw_scores(examples, preds) exact, f1 = get_raw_scores(examples, preds)
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold) exact_threshold = apply_no_ans_threshold(
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
)
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold) f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
evaluation = make_eval_dict(exact_threshold, f1_threshold) evaluation = make_eval_dict(exact_threshold, f1_threshold)
if has_answer_qids: if has_answer_qids:
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids) has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
merge_eval(evaluation, has_ans_eval, 'HasAns') merge_eval(evaluation, has_ans_eval, "HasAns")
if no_answer_qids: if no_answer_qids:
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids) no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
merge_eval(evaluation, no_ans_eval, 'NoAns') merge_eval(evaluation, no_ans_eval, "NoAns")
if no_answer_probs: if no_answer_probs:
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer) find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
...@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -284,8 +290,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
start_position = tok_text.find(pred_text) start_position = tok_text.find(pred_text)
if start_position == -1: if start_position == -1:
if verbose_logging: if verbose_logging:
logger.info( logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text return orig_text
end_position = start_position + len(pred_text) - 1 end_position = start_position + len(pred_text) - 1
...@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -294,8 +299,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
if len(orig_ns_text) != len(tok_ns_text): if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging: if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
orig_ns_text, tok_ns_text)
return orig_text return orig_text
# We then project the characters in `pred_text` back to `orig_text` using # We then project the characters in `pred_text` back to `orig_text` using
...@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -326,7 +330,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
logger.info("Couldn't map end position") logger.info("Couldn't map end position")
return orig_text return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)] output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text return output_text
...@@ -393,8 +397,8 @@ def compute_predictions_logits( ...@@ -393,8 +397,8 @@ def compute_predictions_logits(
unique_id_to_result[result.unique_id] = result unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) )
all_predictions = collections.OrderedDict() all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict() all_nbest_json = collections.OrderedDict()
...@@ -447,7 +451,9 @@ def compute_predictions_logits( ...@@ -447,7 +451,9 @@ def compute_predictions_logits(
start_index=start_index, start_index=start_index,
end_index=end_index, end_index=end_index,
start_logit=result.start_logits[start_index], start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index])) end_logit=result.end_logits[end_index],
)
)
if version_2_with_negative: if version_2_with_negative:
prelim_predictions.append( prelim_predictions.append(
_PrelimPrediction( _PrelimPrediction(
...@@ -455,14 +461,14 @@ def compute_predictions_logits( ...@@ -455,14 +461,14 @@ def compute_predictions_logits(
start_index=0, start_index=0,
end_index=0, end_index=0,
start_logit=null_start_logit, start_logit=null_start_logit,
end_logit=null_end_logit)) end_logit=null_end_logit,
prelim_predictions = sorted( )
prelim_predictions, )
key=lambda x: (x.start_logit + x.end_logit), prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"]) "NbestPrediction", ["text", "start_logit", "end_logit"]
)
seen_predictions = {} seen_predictions = {}
nbest = [] nbest = []
...@@ -471,10 +477,10 @@ def compute_predictions_logits( ...@@ -471,10 +477,10 @@ def compute_predictions_logits(
break break
feature = features[pred.feature_index] feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens) tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
...@@ -498,31 +504,21 @@ def compute_predictions_logits( ...@@ -498,31 +504,21 @@ def compute_predictions_logits(
final_text = "" final_text = ""
seen_predictions[final_text] = True seen_predictions[final_text] = True
nbest.append( nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it # if we didn't include the empty option in the n-best, include it
if version_2_with_negative: if version_2_with_negative:
if "" not in seen_predictions: if "" not in seen_predictions:
nbest.append( nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could only have single null prediction. # In very rare edge cases we could only have single null prediction.
# So we just create a nonce prediction in this case to avoid failure. # So we just create a nonce prediction in this case to avoid failure.
if len(nbest) == 1: if len(nbest) == 1:
nbest.insert(0, nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
# In very rare edge cases we could have no valid predictions. So we # In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure. # just create a nonce prediction in this case to avoid failure.
if not nbest: if not nbest:
nbest.append( nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1 assert len(nbest) >= 1
...@@ -551,8 +547,7 @@ def compute_predictions_logits( ...@@ -551,8 +547,7 @@ def compute_predictions_logits(
all_predictions[example.qas_id] = nbest_json[0]["text"] all_predictions[example.qas_id] = nbest_json[0]["text"]
else: else:
# predict "" iff the null score - the score of best non-null > threshold # predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - ( score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold: if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = "" all_predictions[example.qas_id] = ""
...@@ -586,7 +581,7 @@ def compute_predictions_log_probs( ...@@ -586,7 +581,7 @@ def compute_predictions_log_probs(
end_n_top, end_n_top,
version_2_with_negative, version_2_with_negative,
tokenizer, tokenizer,
verbose_logging verbose_logging,
): ):
""" XLNet write prediction logic (more complex than Bert's). """ XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed. Write final predictions to the json file and log-odds of null if needed.
...@@ -594,12 +589,12 @@ def compute_predictions_log_probs( ...@@ -594,12 +589,12 @@ def compute_predictions_log_probs(
Requires utils_squad_evaluate.py Requires utils_squad_evaluate.py
""" """
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
["feature_index", "start_index", "end_index", )
"start_log_prob", "end_log_prob"])
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]) "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
)
logger.info("Writing predictions to: %s", output_prediction_file) logger.info("Writing predictions to: %s", output_prediction_file)
# logger.info("Writing nbest to: %s" % (output_nbest_file)) # logger.info("Writing nbest to: %s" % (output_nbest_file))
...@@ -663,12 +658,13 @@ def compute_predictions_log_probs( ...@@ -663,12 +658,13 @@ def compute_predictions_log_probs(
start_index=start_index, start_index=start_index,
end_index=end_index, end_index=end_index,
start_log_prob=start_log_prob, start_log_prob=start_log_prob,
end_log_prob=end_log_prob)) end_log_prob=end_log_prob,
)
)
prelim_predictions = sorted( prelim_predictions = sorted(
prelim_predictions, prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
key=lambda x: (x.start_log_prob + x.end_log_prob), )
reverse=True)
seen_predictions = {} seen_predictions = {}
nbest = [] nbest = []
...@@ -688,10 +684,10 @@ def compute_predictions_log_probs( ...@@ -688,10 +684,10 @@ def compute_predictions_log_probs(
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
# Previously used Bert untokenizer # Previously used Bert untokenizer
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens) tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
# Clean whitespace # Clean whitespace
...@@ -704,8 +700,7 @@ def compute_predictions_log_probs( ...@@ -704,8 +700,7 @@ def compute_predictions_log_probs(
else: else:
do_lower_case = tokenizer.do_lowercase_and_remove_accent do_lower_case = tokenizer.do_lowercase_and_remove_accent
final_text = get_final_text(tok_text, orig_text, do_lower_case, final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
verbose_logging)
if final_text in seen_predictions: if final_text in seen_predictions:
continue continue
...@@ -713,17 +708,13 @@ def compute_predictions_log_probs( ...@@ -713,17 +708,13 @@ def compute_predictions_log_probs(
seen_predictions[final_text] = True seen_predictions[final_text] = True
nbest.append( nbest.append(
_NbestPrediction( _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
text=final_text, )
start_log_prob=pred.start_log_prob,
end_log_prob=pred.end_log_prob))
# In very rare edge cases we could have no valid predictions. So we # In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure. # just create a nonce prediction in this case to avoid failure.
if not nbest: if not nbest:
nbest.append( nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
_NbestPrediction(text="", start_log_prob=-1e6,
end_log_prob=-1e6))
total_scores = [] total_scores = []
best_non_null_entry = None best_non_null_entry = None
......
from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor # flake8: noqa
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features # There's no way to ignore "F401 '...' imported but unused" warnings in this
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor # module, but to preserve other warnings. So, don't check this module at all.
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
\ No newline at end of file from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
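A hedged usage sketch of the flat re-exports above: callers can fetch a processor class and its output mode from one place (the label values in the comment are what the MRPC processor returns in the library).

from transformers.data.processors import glue_output_modes, glue_processors

processor = glue_processors["mrpc"]()        # -> an MrpcProcessor instance
print(processor.get_labels())                # ["0", "1"]
print(glue_output_modes["mrpc"])             # "classification"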
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
import logging import logging
import os import os
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available from ...file_utils import is_tf_available
from .utils import DataProcessor, InputExample, InputFeatures
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
...@@ -27,15 +28,18 @@ if is_tf_available(): ...@@ -27,15 +28,18 @@ if is_tf_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def glue_convert_examples_to_features(examples, tokenizer, def glue_convert_examples_to_features(
max_length=512, examples,
task=None, tokenizer,
label_list=None, max_length=512,
output_mode=None, task=None,
pad_on_left=False, label_list=None,
pad_token=0, output_mode=None,
pad_token_segment_id=0, pad_on_left=False,
mask_padding_with_zero=True): pad_token=0,
pad_token_segment_id=0,
mask_padding_with_zero=True,
):
""" """
Loads a data file into a list of ``InputFeatures`` Loads a data file into a list of ``InputFeatures``
...@@ -82,12 +86,7 @@ def glue_convert_examples_to_features(examples, tokenizer, ...@@ -82,12 +86,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
example = processor.get_example_from_tensor_dict(example) example = processor.get_example_from_tensor_dict(example)
example = processor.tfds_map(example) example = processor.tfds_map(example)
inputs = tokenizer.encode_plus( inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
example.text_a,
example.text_b,
add_special_tokens=True,
max_length=max_length,
)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
# The mask has 1 for real tokens and 0 for padding tokens. Only real # The mask has 1 for real tokens and 0 for padding tokens. Only real
...@@ -106,8 +105,12 @@ def glue_convert_examples_to_features(examples, tokenizer, ...@@ -106,8 +105,12 @@ def glue_convert_examples_to_features(examples, tokenizer,
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length) len(attention_mask), max_length
)
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
len(token_type_ids), max_length
)
if output_mode == "classification": if output_mode == "classification":
label = label_map[example.label] label = label_map[example.label]
...@@ -125,28 +128,36 @@ def glue_convert_examples_to_features(examples, tokenizer, ...@@ -125,28 +128,36 @@ def glue_convert_examples_to_features(examples, tokenizer,
logger.info("label: %s (id = %d)" % (example.label, label)) logger.info("label: %s (id = %d)" % (example.label, label))
features.append( features.append(
InputFeatures(input_ids=input_ids, InputFeatures(
attention_mask=attention_mask, input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
token_type_ids=token_type_ids, )
label=label)) )
if is_tf_available() and is_tf_dataset: if is_tf_available() and is_tf_dataset:
def gen(): def gen():
for ex in features: for ex in features:
yield ({'input_ids': ex.input_ids, yield (
'attention_mask': ex.attention_mask, {
'token_type_ids': ex.token_type_ids}, "input_ids": ex.input_ids,
ex.label) "attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
return tf.data.Dataset.from_generator(gen, },
({'input_ids': tf.int32, ex.label,
'attention_mask': tf.int32, )
'token_type_ids': tf.int32},
tf.int64), return tf.data.Dataset.from_generator(
({'input_ids': tf.TensorShape([None]), gen,
'attention_mask': tf.TensorShape([None]), ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
'token_type_ids': tf.TensorShape([None])}, (
tf.TensorShape([]))) {
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
tf.TensorShape([]),
),
)
return features return features
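A hedged end-to-end usage sketch of the function above (model name and data path are placeholders): a processor yields InputExample objects and glue_convert_examples_to_features turns them into fixed-length InputFeatures.

from transformers import BertTokenizer
from transformers.data.processors.glue import MrpcProcessor, glue_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = MrpcProcessor()
examples = processor.get_dev_examples("/path/to/MRPC")          # expects GLUE-style dev.tsv
features = glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=128,
    label_list=processor.get_labels(),
    output_mode="classification",
)
print(len(features[0].input_ids))                               # 128 (padded / truncated)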
...@@ -156,21 +167,21 @@ class MrpcProcessor(DataProcessor): ...@@ -156,21 +167,21 @@ class MrpcProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict["sentence1"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -186,8 +197,7 @@ class MrpcProcessor(DataProcessor): ...@@ -186,8 +197,7 @@ class MrpcProcessor(DataProcessor):
text_a = line[3] text_a = line[3]
text_b = line[4] text_b = line[4]
label = line[0] label = line[0]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -196,21 +206,20 @@ class MnliProcessor(DataProcessor): ...@@ -196,21 +206,20 @@ class MnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['premise'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['hypothesis'].numpy().decode('utf-8'), tensor_dict["premise"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["hypothesis"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -226,8 +235,7 @@ class MnliProcessor(DataProcessor): ...@@ -226,8 +235,7 @@ class MnliProcessor(DataProcessor):
text_a = line[8] text_a = line[8]
text_b = line[9] text_b = line[9]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -236,9 +244,7 @@ class MnliMismatchedProcessor(MnliProcessor): ...@@ -236,9 +244,7 @@ class MnliMismatchedProcessor(MnliProcessor):
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_matched")
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
...@@ -246,20 +252,20 @@ class ColaProcessor(DataProcessor): ...@@ -246,20 +252,20 @@ class ColaProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
None, tensor_dict["sentence"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -272,8 +278,7 @@ class ColaProcessor(DataProcessor): ...@@ -272,8 +278,7 @@ class ColaProcessor(DataProcessor):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
text_a = line[3] text_a = line[3]
label = line[1] label = line[1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
...@@ -282,20 +287,20 @@ class Sst2Processor(DataProcessor): ...@@ -282,20 +287,20 @@ class Sst2Processor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
None, tensor_dict["sentence"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -310,8 +315,7 @@ class Sst2Processor(DataProcessor): ...@@ -310,8 +315,7 @@ class Sst2Processor(DataProcessor):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
text_a = line[0] text_a = line[0]
label = line[1] label = line[1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
...@@ -320,20 +324,20 @@ class StsbProcessor(DataProcessor): ...@@ -320,20 +324,20 @@ class StsbProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict["sentence1"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -349,8 +353,7 @@ class StsbProcessor(DataProcessor): ...@@ -349,8 +353,7 @@ class StsbProcessor(DataProcessor):
text_a = line[7] text_a = line[7]
text_b = line[8] text_b = line[8]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -359,20 +362,20 @@ class QqpProcessor(DataProcessor): ...@@ -359,20 +362,20 @@ class QqpProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['question1'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['question2'].numpy().decode('utf-8'), tensor_dict["question1"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["question2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -391,8 +394,7 @@ class QqpProcessor(DataProcessor): ...@@ -391,8 +394,7 @@ class QqpProcessor(DataProcessor):
label = line[5] label = line[5]
except IndexError: except IndexError:
continue continue
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -401,21 +403,20 @@ class QnliProcessor(DataProcessor): ...@@ -401,21 +403,20 @@ class QnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['question'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'), tensor_dict["question"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["sentence"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
"dev_matched")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -431,8 +432,7 @@ class QnliProcessor(DataProcessor): ...@@ -431,8 +432,7 @@ class QnliProcessor(DataProcessor):
text_a = line[1] text_a = line[1]
text_b = line[2] text_b = line[2]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -441,20 +441,20 @@ class RteProcessor(DataProcessor): ...@@ -441,20 +441,20 @@ class RteProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict["sentence1"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -470,8 +470,7 @@ class RteProcessor(DataProcessor): ...@@ -470,8 +470,7 @@ class RteProcessor(DataProcessor):
text_a = line[1] text_a = line[1]
text_b = line[2] text_b = line[2]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -480,20 +479,20 @@ class WnliProcessor(DataProcessor): ...@@ -480,20 +479,20 @@ class WnliProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['sentence1'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['sentence2'].numpy().decode('utf-8'), tensor_dict["sentence1"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -509,10 +508,10 @@ class WnliProcessor(DataProcessor): ...@@ -509,10 +508,10 @@ class WnliProcessor(DataProcessor):
text_a = line[1] text_a = line[1]
text_b = line[2] text_b = line[2]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
glue_tasks_num_labels = { glue_tasks_num_labels = {
"cola": 2, "cola": 2,
"mnli": 3, "mnli": 3,
......
from tqdm import tqdm import json
import collections
import logging import logging
import os import os
import json
import numpy as np
from multiprocessing import Pool
from multiprocessing import cpu_count
from functools import partial from functools import partial
from multiprocessing import Pool, cpu_count
import numpy as np
from tqdm import tqdm
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available, is_torch_available from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from .utils import DataProcessor
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -82,8 +82,8 @@ def _is_whitespace(c): ...@@ -82,8 +82,8 @@ def _is_whitespace(c):
return True return True
return False return False
def squad_convert_example_to_features(example, max_seq_length,
doc_stride, max_query_length, is_training): def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
features = [] features = []
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
# Get start and end position # Get start and end position
...@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length, ...@@ -91,7 +91,7 @@ def squad_convert_example_to_features(example, max_seq_length,
end_position = example.end_position end_position = example.end_position
# If the answer cannot be found in the text, then skip this example. # If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)]) actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
if actual_text.find(cleaned_answer_text) == -1: if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
...@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length, ...@@ -121,8 +121,11 @@ def squad_convert_example_to_features(example, max_seq_length,
spans = [] spans = []
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \ sequence_added_tokens = (
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence tokenizer.max_len - tokenizer.max_len_single_sentence + 1
if "roberta" in str(type(tokenizer))
else tokenizer.max_len - tokenizer.max_len_single_sentence
)
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
span_doc_tokens = all_doc_tokens span_doc_tokens = all_doc_tokens
...@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length, ...@@ -135,16 +138,18 @@ def squad_convert_example_to_features(example, max_seq_length,
return_overflowing_tokens=True, return_overflowing_tokens=True,
pad_to_max_length=True, pad_to_max_length=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first' truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
) )
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, paragraph_len = min(
max_seq_length - len(truncated_query) - sequence_pair_added_tokens) len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
)
if tokenizer.pad_token_id in encoded_dict['input_ids']: if tokenizer.pad_token_id in encoded_dict["input_ids"]:
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)] non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
else: else:
non_padded_ids = encoded_dict['input_ids'] non_padded_ids = encoded_dict["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
...@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length, ...@@ -170,17 +175,20 @@ def squad_convert_example_to_features(example, max_seq_length,
for doc_span_index in range(len(spans)): for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]): for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = j if tokenizer.padding_side == "left" else spans[doc_span_index][ index = (
"truncated_query_with_special_tokens_length"] + j j
if tokenizer.padding_side == "left"
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
)
spans[doc_span_index]["token_is_max_context"][index] = is_max_context spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans: for span in spans:
# Identify the position of the CLS token # Identify the position of the CLS token
cls_index = span['input_ids'].index(tokenizer.cls_token_id) cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer) # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer)
# The original TF implementation also keeps the classification token (set to 0) (not sure why...) # The original TF implementation also keeps the classification token (set to 0) (not sure why...)
p_mask = np.array(span['token_type_ids']) p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1) p_mask = np.minimum(p_mask, 1)
...@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length, ...@@ -219,31 +227,34 @@ def squad_convert_example_to_features(example, max_seq_length,
start_position = tok_start_position - doc_start + doc_offset start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset
features.append(SquadFeatures( features.append(
span['input_ids'], SquadFeatures(
span['attention_mask'], span["input_ids"],
span['token_type_ids'], span["attention_mask"],
cls_index, span["token_type_ids"],
p_mask.tolist(), cls_index,
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing. p_mask.tolist(),
unique_id=0, example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
paragraph_len=span['paragraph_len'], unique_id=0,
token_is_max_context=span["token_is_max_context"], paragraph_len=span["paragraph_len"],
tokens=span["tokens"], token_is_max_context=span["token_is_max_context"],
token_to_orig_map=span["token_to_orig_map"], tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position, start_position=start_position,
end_position=end_position end_position=end_position,
)) )
)
return features return features
def squad_convert_example_to_features_init(tokenizer_for_convert): def squad_convert_example_to_features_init(tokenizer_for_convert):
global tokenizer global tokenizer
tokenizer = tokenizer_for_convert tokenizer = tokenizer_for_convert
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_stride, max_query_length, is_training, def squad_convert_examples_to_features(
return_dataset=False, threads=1): examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1
):
""" """
Converts a list of examples into a list of features that can be directly given as input to a model. Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs. It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
...@@ -269,7 +280,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -269,7 +280,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
processor = SquadV2Processor() processor = SquadV2Processor()
examples = processor.get_dev_examples(data_dir) examples = processor.get_dev_examples(data_dir)
features = squad_convert_examples_to_features( features = squad_convert_examples_to_features(
examples=examples, examples=examples,
tokenizer=tokenizer, tokenizer=tokenizer,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
...@@ -279,17 +290,28 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -279,17 +290,28 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
) )
""" """
# Defining helper methods # Defining helper methods
features = [] features = []
threads = min(threads, cpu_count()) threads = min(threads, cpu_count())
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length, annotate_ = partial(
doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training) squad_convert_example_to_features,
features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples), desc='convert squad examples to features')) max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=is_training,
)
features = list(
tqdm(
p.imap(annotate_, examples, chunksize=32),
total=len(examples),
desc="convert squad examples to features",
)
)
new_features = [] new_features = []
unique_id = 1000000000 unique_id = 1000000000
example_index = 0 example_index = 0
for example_features in tqdm(features, total=len(features), desc='add example index and unique id'): for example_features in tqdm(features, total=len(features), desc="add example index and unique id"):
if not example_features: if not example_features:
continue continue
for example_feature in example_features: for example_feature in example_features:
...@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -300,7 +322,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
example_index += 1 example_index += 1
features = new_features features = new_features
del new_features del new_features
if return_dataset == 'pt': if return_dataset == "pt":
if not is_torch_available(): if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.") raise ImportError("Pytorch must be installed to return a pytorch dataset.")
...@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -341,12 +363,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
"input_ids": ex.input_ids, "input_ids": ex.input_ids,
"attention_mask": ex.attention_mask, "attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids, "token_type_ids": ex.token_type_ids,
}, { },
{
"start_position": ex.start_position, "start_position": ex.start_position,
"end_position": ex.end_position, "end_position": ex.end_position,
"cls_index": ex.cls_index, "cls_index": ex.cls_index,
"p_mask": ex.p_mask, "p_mask": ex.p_mask,
} },
) )
return tf.data.Dataset.from_generator( return tf.data.Dataset.from_generator(
...@@ -616,8 +639,8 @@ class SquadFeatures(object): ...@@ -616,8 +639,8 @@ class SquadFeatures(object):
has more information related to that token and should be prioritized over this feature for that token. has more information related to that token and should be prioritized over this feature for that token.
tokens: list of tokens corresponding to the input ids tokens: list of tokens corresponding to the input ids
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
start_position: start of the answer token index start_position: start of the answer token index
end_position: end of the answer token index end_position: end of the answer token index
""" """
def __init__( def __init__(
......
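The squad_convert_examples_to_features refactor above keeps the Pool-with-initializer pattern: the tokenizer is installed as a module-level global once per worker instead of being pickled with every task. Below is a stripped-down sketch of that pattern; the worker function and the toy "tokenizer" are illustrative assumptions, not the SQuAD code itself.

    from functools import partial
    from multiprocessing import Pool, cpu_count

    def _init_worker(shared_tokenizer):
        # Runs once in each worker process; stores the heavy object as a module-level global.
        global tokenizer
        tokenizer = shared_tokenizer

    def _annotate(text, max_length):
        # Illustrative worker: uses the global installed by _init_worker.
        return tokenizer(text)[:max_length]

    def convert_all(texts, shared_tokenizer, max_length, threads=1):
        threads = min(threads, cpu_count())
        with Pool(threads, initializer=_init_worker, initargs=(shared_tokenizer,)) as pool:
            work = partial(_annotate, max_length=max_length)
            # chunksize > 1 amortises inter-process overhead, as in the diff above.
            return list(pool.imap(work, texts, chunksize=32))

On platforms that spawn rather than fork, convert_all must be called from a main-guarded entry point, and shared_tokenizer must be picklable.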
...@@ -14,16 +14,18 @@ ...@@ -14,16 +14,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import csv
import sys
import copy import copy
import csv
import json import json
import logging import logging
import sys
from ...file_utils import is_tf_available, is_torch_available from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InputExample(object): class InputExample(object):
""" """
A single training/test example for simple sequence classification. A single training/test example for simple sequence classification.
...@@ -37,6 +39,7 @@ class InputExample(object): ...@@ -37,6 +39,7 @@ class InputExample(object):
label: (Optional) string. The label of the example. This should be label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples. specified for train and dev examples, but not for test examples.
""" """
def __init__(self, guid, text_a, text_b=None, label=None): def __init__(self, guid, text_a, text_b=None, label=None):
self.guid = guid self.guid = guid
self.text_a = text_a self.text_a = text_a
...@@ -99,14 +102,15 @@ class DataProcessor(object): ...@@ -99,14 +102,15 @@ class DataProcessor(object):
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line) line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
lines.append(line) lines.append(line)
return lines return lines
class SingleSentenceClassificationProcessor(DataProcessor): class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set.""" """ Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
self.labels = [] if labels is None else labels self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples self.examples = [] if examples is None else examples
self.mode = mode self.mode = mode
...@@ -117,22 +121,24 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -117,22 +121,24 @@ class SingleSentenceClassificationProcessor(DataProcessor):
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
return SingleSentenceClassificationProcessor(labels=self.labels, return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
examples=self.examples[idx])
return self.examples[idx] return self.examples[idx]
@classmethod @classmethod
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1, def create_from_csv(
column_id=None, skip_first_row=False, **kwargs): cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
):
processor = cls(**kwargs) processor = cls(**kwargs)
processor.add_examples_from_csv(file_name, processor.add_examples_from_csv(
split_name=split_name, file_name,
column_label=column_label, split_name=split_name,
column_text=column_text, column_label=column_label,
column_id=column_id, column_text=column_text,
skip_first_row=skip_first_row, column_id=column_id,
overwrite_labels=True, skip_first_row=skip_first_row,
overwrite_examples=True) overwrite_labels=True,
overwrite_examples=True,
)
return processor return processor
@classmethod @classmethod
...@@ -141,8 +147,17 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -141,8 +147,17 @@ class SingleSentenceClassificationProcessor(DataProcessor):
processor.add_examples(texts_or_text_and_labels, labels=labels) processor.add_examples(texts_or_text_and_labels, labels=labels)
return processor return processor
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None, def add_examples_from_csv(
skip_first_row=False, overwrite_labels=False, overwrite_examples=False): self,
file_name,
split_name="",
column_label=0,
column_text=1,
column_id=None,
skip_first_row=False,
overwrite_labels=False,
overwrite_examples=False,
):
lines = self._read_tsv(file_name) lines = self._read_tsv(file_name)
if skip_first_row: if skip_first_row:
lines = lines[1:] lines = lines[1:]
...@@ -158,10 +173,13 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -158,10 +173,13 @@ class SingleSentenceClassificationProcessor(DataProcessor):
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid) ids.append(guid)
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples) return self.add_examples(
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
)
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None, def add_examples(
overwrite_labels=False, overwrite_examples=False): self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
):
assert labels is None or len(texts_or_text_and_labels) == len(labels) assert labels is None or len(texts_or_text_and_labels) == len(labels)
assert ids is None or len(texts_or_text_and_labels) == len(ids) assert ids is None or len(texts_or_text_and_labels) == len(ids)
if ids is None: if ids is None:
...@@ -192,13 +210,15 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -192,13 +210,15 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return self.examples return self.examples
def get_features(self, def get_features(
tokenizer, self,
max_length=None, tokenizer,
pad_on_left=False, max_length=None,
pad_token=0, pad_on_left=False,
mask_padding_with_zero=True, pad_token=0,
return_tensors=None): mask_padding_with_zero=True,
return_tensors=None,
):
""" """
Convert examples in a list of ``InputFeatures`` Convert examples in a list of ``InputFeatures``
...@@ -231,9 +251,7 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -231,9 +251,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
logger.info("Tokenizing example %d", ex_index) logger.info("Tokenizing example %d", ex_index)
input_ids = tokenizer.encode( input_ids = tokenizer.encode(
example.text_a, example.text_a, add_special_tokens=True, max_length=min(max_length, tokenizer.max_len),
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
) )
all_input_ids.append(input_ids) all_input_ids.append(input_ids)
...@@ -256,8 +274,12 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -256,8 +274,12 @@ class SingleSentenceClassificationProcessor(DataProcessor):
input_ids = input_ids + ([pad_token] * padding_length) input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length) assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length) len(input_ids), batch_length
)
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(
len(attention_mask), batch_length
)
if self.mode == "classification": if self.mode == "classification":
label = label_map[example.label] label = label_map[example.label]
...@@ -273,36 +295,31 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -273,36 +295,31 @@ class SingleSentenceClassificationProcessor(DataProcessor):
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("label: %s (id = %d)" % (example.label, label)) logger.info("label: %s (id = %d)" % (example.label, label))
features.append( features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
label=label))
if return_tensors is None: if return_tensors is None:
return features return features
elif return_tensors == 'tf': elif return_tensors == "tf":
if not is_tf_available(): if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported") raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf import tensorflow as tf
def gen(): def gen():
for ex in features: for ex in features:
yield ({'input_ids': ex.input_ids, yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
'attention_mask': ex.attention_mask},
ex.label) dataset = tf.data.Dataset.from_generator(
gen,
dataset = tf.data.Dataset.from_generator(gen, ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
({'input_ids': tf.int32, ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
'attention_mask': tf.int32}, )
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None])},
tf.TensorShape([])))
return dataset return dataset
elif return_tensors == 'pt': elif return_tensors == "pt":
if not is_torch_available(): if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported") raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch import torch
from torch.utils.data import TensorDataset from torch.utils.data import TensorDataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
if self.mode == "classification": if self.mode == "classification":
......
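As a usage sketch for the reformatted SingleSentenceClassificationProcessor above; the example texts, labels and tokenizer checkpoint are assumptions, and the import path follows the module shown in this diff.

    from transformers import BertTokenizer
    from transformers.data.processors.utils import SingleSentenceClassificationProcessor

    texts = ["a great movie", "a terrible movie"]
    labels = ["pos", "neg"]

    processor = SingleSentenceClassificationProcessor(mode="classification")
    processor.add_examples(texts, labels=labels)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    features = processor.get_features(tokenizer, max_length=64)                       # list of InputFeatures
    dataset = processor.get_features(tokenizer, max_length=64, return_tensors="pt")   # torch TensorDataset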
...@@ -22,13 +22,15 @@ import os ...@@ -22,13 +22,15 @@ import os
from .utils import DataProcessor, InputExample from .utils import DataProcessor, InputExample
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class XnliProcessor(DataProcessor): class XnliProcessor(DataProcessor):
"""Processor for the XNLI dataset. """Processor for the XNLI dataset.
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207""" Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
def __init__(self, language, train_language = None): def __init__(self, language, train_language=None):
self.language = language self.language = language
self.train_language = train_language self.train_language = train_language
...@@ -40,13 +42,12 @@ class XnliProcessor(DataProcessor): ...@@ -40,13 +42,12 @@ class XnliProcessor(DataProcessor):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % ('train', i) guid = "%s-%s" % ("train", i)
text_a = line[0] text_a = line[0]
text_b = line[1] text_b = line[1]
label = "contradiction" if line[2] == "contradictory" else line[2] label = "contradiction" if line[2] == "contradictory" else line[2]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
...@@ -59,19 +60,19 @@ class XnliProcessor(DataProcessor): ...@@ -59,19 +60,19 @@ class XnliProcessor(DataProcessor):
language = line[0] language = line[0]
if language != self.language: if language != self.language:
continue continue
guid = "%s-%s" % ('test', i) guid = "%s-%s" % ("test", i)
text_a = line[6] text_a = line[6]
text_b = line[7] text_b = line[7]
label = line[1] label = line[1]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str) assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["contradiction", "entailment", "neutral"] return ["contradiction", "entailment", "neutral"]
xnli_processors = { xnli_processors = {
"xnli": XnliProcessor, "xnli": XnliProcessor,
} }
......
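A brief sketch of the cross-lingual setup XnliProcessor supports (a training language separate from the evaluation language); the data directory is an assumption.

    from transformers.data.processors.xnli import XnliProcessor

    # Assumed layout: train split read in English, test split filtered to German.
    processor = XnliProcessor(language="de", train_language="en")
    labels = processor.get_labels()                  # ["contradiction", "entailment", "neutral"]
    train_examples = processor.get_train_examples("xnli_data")
    test_examples = processor.get_test_examples("xnli_data")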
...@@ -3,35 +3,37 @@ Utilities for working with the local dataset cache. ...@@ -3,35 +3,37 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors. Copyright by the AllenNLP authors.
""" """
from __future__ import (absolute_import, division, print_function, unicode_literals) from __future__ import absolute_import, division, print_function, unicode_literals
import sys import fnmatch
import json import json
import logging import logging
import os import os
import six import sys
import tempfile import tempfile
import fnmatch from contextlib import contextmanager
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from io import open from io import open
import boto3 import boto3
import requests
import six
from botocore.config import Config from botocore.config import Config
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import requests from filelock import FileLock
from tqdm.auto import tqdm from tqdm.auto import tqdm
from contextlib import contextmanager
from . import __version__ from . import __version__
from filelock import FileLock
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
os.environ.setdefault('USE_TORCH', 'YES') os.environ.setdefault("USE_TORCH", "YES")
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'): if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
import torch import torch
_torch_available = True # pylint: disable=invalid-name _torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__)) logger.info("PyTorch version {} available.".format(torch.__version__))
else: else:
...@@ -41,10 +43,11 @@ except ImportError: ...@@ -41,10 +43,11 @@ except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
try: try:
os.environ.setdefault('USE_TF', 'YES') os.environ.setdefault("USE_TF", "YES")
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'): if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
import tensorflow as tf import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name _tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__)) logger.info("TensorFlow version {} available.".format(tf.__version__))
else: else:
...@@ -55,12 +58,13 @@ except (ImportError, AssertionError): ...@@ -55,12 +58,13 @@ except (ImportError, AssertionError):
try: try:
from torch.hub import _get_torch_home from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home() torch_cache_home = _get_torch_home()
except ImportError: except ImportError:
torch_cache_home = os.path.expanduser( torch_cache_home = os.path.expanduser(
os.getenv('TORCH_HOME', os.path.join( os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) )
default_cache_path = os.path.join(torch_cache_home, 'transformers') default_cache_path = os.path.join(torch_cache_home, "transformers")
try: try:
from urllib.parse import urlparse from urllib.parse import urlparse
...@@ -69,19 +73,21 @@ except ImportError: ...@@ -69,19 +73,21 @@ except ImportError:
try: try:
from pathlib import Path from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path( PYTORCH_PRETRAINED_BERT_CACHE = Path(
os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
)
except (AttributeError, ImportError): except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
default_cache_path)) )
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5' TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json" MODEL_CARD_NAME = "modelcard.json"
...@@ -95,38 +101,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net" ...@@ -95,38 +101,48 @@ CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
def is_tf_available(): def is_tf_available():
return _tf_available return _tf_available
if not six.PY2: if not six.PY2:
def add_start_docstrings(*docstr): def add_start_docstrings(*docstr):
def docstring_decorator(fn): def docstring_decorator(fn):
fn.__doc__ = ''.join(docstr) + fn.__doc__ fn.__doc__ = "".join(docstr) + fn.__doc__
return fn return fn
return docstring_decorator return docstring_decorator
def add_end_docstrings(*docstr): def add_end_docstrings(*docstr):
def docstring_decorator(fn): def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + ''.join(docstr) fn.__doc__ = fn.__doc__ + "".join(docstr)
return fn return fn
return docstring_decorator return docstring_decorator
else: else:
# Not possible to update class docstrings on python2 # Not possible to update class docstrings on python2
def add_start_docstrings(*docstr): def add_start_docstrings(*docstr):
def docstring_decorator(fn): def docstring_decorator(fn):
return fn return fn
return docstring_decorator return docstring_decorator
def add_end_docstrings(*docstr): def add_end_docstrings(*docstr):
def docstring_decorator(fn): def docstring_decorator(fn):
return fn return fn
return docstring_decorator return docstring_decorator
def is_remote_url(url_or_filename): def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
return parsed.scheme in ('http', 'https', 's3') return parsed.scheme in ("http", "https", "s3")
def hf_bucket_url(identifier, postfix=None, cdn=False): def hf_bucket_url(identifier, postfix=None, cdn=False):
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
...@@ -145,17 +161,17 @@ def url_to_filename(url, etag=None): ...@@ -145,17 +161,17 @@ def url_to_filename(url, etag=None):
so that TF 2.0 can identify it as a HDF5 file so that TF 2.0 can identify it as a HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
""" """
url_bytes = url.encode('utf-8') url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes) url_hash = sha256(url_bytes)
filename = url_hash.hexdigest() filename = url_hash.hexdigest()
if etag: if etag:
etag_bytes = etag.encode('utf-8') etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes) etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest() filename += "." + etag_hash.hexdigest()
if url.endswith('.h5'): if url.endswith(".h5"):
filename += '.h5' filename += ".h5"
return filename return filename
...@@ -174,19 +190,21 @@ def filename_to_url(filename, cache_dir=None): ...@@ -174,19 +190,21 @@ def filename_to_url(filename, cache_dir=None):
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path)) raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json' meta_path = cache_path + ".json"
if not os.path.exists(meta_path): if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path)) raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file: with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file) metadata = json.load(meta_file)
url = metadata['url'] url = metadata["url"]
etag = metadata['etag'] etag = metadata["etag"]
return url, etag return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None): def cached_path(
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -207,13 +225,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -207,13 +225,18 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if is_remote_url(url_or_filename): if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir, return get_from_cache(
force_download=force_download, proxies=proxies, url_or_filename,
resume_download=resume_download, user_agent=user_agent) cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
)
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
return url_or_filename return url_or_filename
elif urlparse(url_or_filename).scheme == '': elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist. # File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename)) raise EnvironmentError("file {} not found".format(url_or_filename))
else: else:
...@@ -273,31 +296,35 @@ def s3_get(url, temp_file, proxies=None): ...@@ -273,31 +296,35 @@ def s3_get(url, temp_file, proxies=None):
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if isinstance(user_agent, dict): if isinstance(user_agent, dict):
ua += "; " + "; ".join( ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
"{}/{}".format(k, v) for k, v in user_agent.items()
)
elif isinstance(user_agent, six.string_types): elif isinstance(user_agent, six.string_types):
ua += "; "+ user_agent ua += "; " + user_agent
headers = { headers = {"user-agent": ua}
"user-agent": ua
}
if resume_size > 0: if resume_size > 0:
headers['Range'] = 'bytes=%d-' % (resume_size,) headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers) response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable if response.status_code == 416: # Range not satisfiable
return return
content_length = response.headers.get('Content-Length') content_length = response.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size, progress = tqdm(
desc="Downloading", disable=bool(logger.level<=logging.INFO)) unit="B",
unit_scale=True,
total=total,
initial=resume_size,
desc="Downloading",
disable=bool(logger.level <= logging.INFO),
)
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) progress.update(len(chunk))
temp_file.write(chunk) temp_file.write(chunk)
progress.close() progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None): def get_from_cache(
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
...@@ -326,7 +353,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -326,7 +353,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
etag = None etag = None
if sys.version_info[0] == 2 and etag is not None: if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8') etag = etag.decode("utf-8")
filename = url_to_filename(url, etag) filename = url_to_filename(url, etag)
# get cache path to put the file # get cache path to put the file
...@@ -337,22 +364,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -337,22 +364,24 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if not os.path.exists(cache_path) and etag is None: if not os.path.exists(cache_path) and etag is None:
matching_files = [ matching_files = [
file file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*') for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
if not file.endswith('.json') and not file.endswith('.lock') if not file.endswith(".json") and not file.endswith(".lock")
] ]
if matching_files: if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1]) cache_path = os.path.join(cache_dir, matching_files[-1])
# Prevent parallel downloads of the same file with a lock. # Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock' lock_path = cache_path + ".lock"
with FileLock(lock_path): with FileLock(lock_path):
if resume_download: if resume_download:
incomplete_path = cache_path + '.incomplete' incomplete_path = cache_path + ".incomplete"
@contextmanager @contextmanager
def _resumable_file_manager(): def _resumable_file_manager():
with open(incomplete_path,'a+b') as f: with open(incomplete_path, "a+b") as f:
yield f yield f
temp_file_manager = _resumable_file_manager temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path): if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size resume_size = os.stat(incomplete_path).st_size
...@@ -366,7 +395,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -366,7 +395,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# Download to temporary file, then copy to cache dir once finished. # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted. # Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file: with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) logger.info(
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
)
# GET file object # GET file object
if url.startswith("s3://"): if url.startswith("s3://"):
...@@ -383,12 +414,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -383,12 +414,12 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
os.rename(temp_file.name, cache_path) os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag} meta = {"url": url, "etag": etag}
meta_path = cache_path + '.json' meta_path = cache_path + ".json"
with open(meta_path, 'w') as meta_file: with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta) output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str): if sys.version_info[0] == 2 and isinstance(output_string, str):
output_string = unicode(output_string, 'utf-8') # The beauty of python 2 output_string = unicode(output_string, "utf-8") # noqa: F821
meta_file.write(output_string) meta_file.write(output_string)
return cache_path return cache_path
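For reference, a minimal sketch of how the reformatted download helpers above fit together; the model identifier is an assumption, and the exact cache location depends on the environment variables defined earlier in this file.

    from transformers.file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url

    # Resolve the S3/CloudFront URL for a hosted checkpoint, then download and cache it.
    url = hf_bucket_url("bert-base-uncased", postfix=WEIGHTS_NAME)
    local_path = cached_path(url, cache_dir=None, force_download=False, resume_download=True)
    print(local_path)  # hashed file name under the default cache (e.g. ~/.cache/torch/transformers)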
...@@ -14,23 +14,26 @@ ...@@ -14,23 +14,26 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import io
import os import os
from os.path import expanduser from os.path import expanduser
from typing import List
import requests import requests
import six import six
from requests.exceptions import HTTPError
from tqdm import tqdm from tqdm import tqdm
ENDPOINT = "https://huggingface.co" ENDPOINT = "https://huggingface.co"
class S3Obj: class S3Obj:
def __init__( def __init__(
self, self,
filename, # type: str filename, # type: str
LastModified, # type: str LastModified, # type: str
ETag, # type: str ETag, # type: str
Size, # type: int Size, # type: int
**kwargs **kwargs
): ):
self.filename = filename self.filename = filename
...@@ -43,13 +46,13 @@ class PresignedUrl: ...@@ -43,13 +46,13 @@ class PresignedUrl:
def __init__( def __init__(
self, self,
write, # type: str write, # type: str
access, # type: str access, # type: str
type, # type: str type, # type: str
**kwargs **kwargs
): ):
self.write = write self.write = write
self.access = access self.access = access
self.type = type # mime-type to send to S3. self.type = type # mime-type to send to S3.
class HfApi: class HfApi:
...@@ -58,8 +61,8 @@ class HfApi: ...@@ -58,8 +61,8 @@ class HfApi:
def login( def login(
self, self,
username, # type: str username, # type: str
password, # type: str password, # type: str
): ):
# type: (...) -> str # type: (...) -> str
""" """
...@@ -78,8 +81,7 @@ class HfApi: ...@@ -78,8 +81,7 @@ class HfApi:
return d["token"] return d["token"]
def whoami( def whoami(
self, self, token, # type: str
token, # type: str
): ):
# type: (...) -> str # type: (...) -> str
""" """
...@@ -92,7 +94,7 @@ class HfApi: ...@@ -92,7 +94,7 @@ class HfApi:
return d["user"] return d["user"]
def logout(self, token): def logout(self, token):
# type: (...) -> void # type: (...) -> None
""" """
Call HF API to log out. Call HF API to log out.
""" """
...@@ -106,11 +108,7 @@ class HfApi: ...@@ -106,11 +108,7 @@ class HfApi:
Call HF API to get a presigned url to upload `filename` to S3. Call HF API to get a presigned url to upload `filename` to S3.
""" """
path = "{}/api/presign".format(self.endpoint) path = "{}/api/presign".format(self.endpoint)
r = requests.post( r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename},)
path,
headers={"authorization": "Bearer {}".format(token)},
json={"filename": filename},
)
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
return PresignedUrl(**d) return PresignedUrl(**d)
...@@ -126,22 +124,19 @@ class HfApi: ...@@ -126,22 +124,19 @@ class HfApi:
urls = self.presign(token, filename=filename) urls = self.presign(token, filename=filename)
# streaming upload: # streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
# #
# Even though we presign with the correct content-type, # Even though we presign with the correct content-type,
# the client still has to specify it when uploading the file. # the client still has to specify it when uploading the file.
with open(filepath, "rb") as f: with open(filepath, "rb") as f:
pf = TqdmProgressFileReader(f) pf = TqdmProgressFileReader(f)
data = f if pf.total_size > 0 else "" data = f if pf.total_size > 0 else ""
r = requests.put(urls.write, data=data, headers={ r = requests.put(urls.write, data=data, headers={"content-type": urls.type})
"content-type": urls.type,
})
r.raise_for_status() r.raise_for_status()
pf.close() pf.close()
return urls.access return urls.access
def list_objs(self, token): def list_objs(self, token) -> List[S3Obj]:
# type: (...) -> List[S3Obj]
""" """
Call HF API to list all stored files for user. Call HF API to list all stored files for user.
""" """
...@@ -152,7 +147,6 @@ class HfApi: ...@@ -152,7 +147,6 @@ class HfApi:
return [S3Obj(**x) for x in d] return [S3Obj(**x) for x in d]
class TqdmProgressFileReader: class TqdmProgressFileReader:
""" """
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
...@@ -161,12 +155,10 @@ class TqdmProgressFileReader: ...@@ -161,12 +155,10 @@ class TqdmProgressFileReader:
see github.com/huggingface/transformers/pull/2078#discussion_r354739608 see github.com/huggingface/transformers/pull/2078#discussion_r354739608
for implementation details. for implementation details.
""" """
def __init__(
self, def __init__(self, f: io.BufferedReader):
f # type: io.BufferedReader
):
self.f = f self.f = f
self.total_size = os.fstat(f.fileno()).st_size # type: int self.total_size = os.fstat(f.fileno()).st_size # type: int
self.pbar = tqdm(total=self.total_size, leave=False) self.pbar = tqdm(total=self.total_size, leave=False)
if six.PY3: if six.PY3:
# does not work unless PY3 # does not work unless PY3
...@@ -182,7 +174,6 @@ class TqdmProgressFileReader: ...@@ -182,7 +174,6 @@ class TqdmProgressFileReader:
self.pbar.close() self.pbar.close()
class HfFolder: class HfFolder:
path_token = expanduser("~/.huggingface/token") path_token = expanduser("~/.huggingface/token")
...@@ -201,7 +192,7 @@ class HfFolder: ...@@ -201,7 +192,7 @@ class HfFolder:
if e.errno != os.errno.EEXIST: if e.errno != os.errno.EEXIST:
raise e raise e
pass pass
with open(cls.path_token, 'w+') as f: with open(cls.path_token, "w+") as f:
f.write(token) f.write(token)
@classmethod @classmethod
...@@ -210,12 +201,10 @@ class HfFolder: ...@@ -210,12 +201,10 @@ class HfFolder:
Get token or None if not existent. Get token or None if not existent.
""" """
try: try:
with open(cls.path_token, 'r') as f: with open(cls.path_token, "r") as f:
return f.read() return f.read()
except: except FileNotFoundError:
# this is too wide. When Py2 is dead use: pass
# `except FileNotFoundError:` instead
return None
@classmethod @classmethod
def delete_token(cls): def delete_token(cls):
...@@ -225,5 +214,5 @@ class HfFolder: ...@@ -225,5 +214,5 @@ class HfFolder:
""" """
try: try:
os.remove(cls.path_token) os.remove(cls.path_token)
except: except FileNotFoundError:
return pass
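Finally, a hedged sketch of the client-side flow hf_api.py exposes; the credentials and file name below are placeholders, and HfApi() is assumed to default to the huggingface.co endpoint defined in this file.

    from transformers.hf_api import HfApi, HfFolder

    api = HfApi()
    token = api.login(username="my-username", password="my-password")  # placeholder credentials
    HfFolder.save_token(token)            # persists the token to ~/.huggingface/token

    print(api.whoami(token))              # username associated with the token
    urls = api.presign(token, filename="my-model/pytorch_model.bin")   # PresignedUrl(write, access, type)
    for obj in api.list_objs(token):      # S3Obj entries: filename, LastModified, ETag, Size
        print(obj.filename, obj.Size)

    api.logout(token)
    HfFolder.delete_token()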