"...docs/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "aa59fca5d78e9d4652c0c9c837dc6c465cde0787"
Commit 1b35d05d authored by thomwolf

update conversion scripts and __main__

parent 352e3ff9
# coding: utf8

def main():
    import sys
-   if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
+   if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
        print(
            "Should be used as one of: \n"
-           ">> `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-           ">> `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
-           ">> `pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
-           ">> `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
-           ">> `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
+           ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
+           ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
+           ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
+           ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
+           ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
+           ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
    else:
        if sys.argv[1] == "bert":
            try:
@@ -86,7 +87,7 @@ def main():
            else:
                TF_CONFIG = ""
            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-       else:
+       elif sys.argv[1] == "xlnet":
            try:
                from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
            except ImportError:
@@ -104,11 +105,24 @@ def main():
            PYTORCH_DUMP_OUTPUT = sys.argv[4]
            if len(sys.argv) == 6:
                FINETUNING_TASK = sys.argv[5]
+           else:
+               FINETUNING_TASK = None
            convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
                                                TF_CONFIG,
                                                PYTORCH_DUMP_OUTPUT,
                                                FINETUNING_TASK)
+       elif sys.argv[1] == "xlm":
+           from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
+           if len(sys.argv) != 4:
+               # pylint: disable=line-too-long
+               print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
+           else:
+               XLM_CHECKPOINT_PATH = sys.argv[2]
+               PYTORCH_DUMP_OUTPUT = sys.argv[3]
+               convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)

if __name__ == '__main__':
    main()
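With the new "xlm" branch in place, the converter can also be reached without the CLI by importing it from the package. A minimal sketch; both paths are placeholders for illustration and are not taken from the commit:

from pytorch_transformers.convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch

# Equivalent to: python -m pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT
# Placeholder paths for illustration only.
convert_xlm_checkpoint_to_pytorch("/path/to/xlm_checkpoint.pth", "/path/to/pytorch_dump_folder")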
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                 GPT2Model,
                                                 load_tf_weights_in_gpt2)

+import logging
+logging.basicConfig(level=logging.INFO)

def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct model
@@ -36,7 +39,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
    model = GPT2Model(config)

    # Load weights from numpy
-   load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
+   load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
......
@@ -26,6 +26,9 @@ from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
                                                   OpenAIGPTModel,
                                                   load_tf_weights_in_openai_gpt)

+import logging
+logging.basicConfig(level=logging.INFO)

def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
    # Construct model
@@ -36,7 +39,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
    model = OpenAIGPTModel(config)

    # Load weights from numpy
-   load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
+   load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
......
@@ -18,15 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import os
-import re
import argparse
-import tensorflow as tf
import torch
-import numpy as np

from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert

+import logging
+logging.basicConfig(level=logging.INFO)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
@@ -34,7 +33,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
-   load_tf_weights_in_bert(model, tf_checkpoint_path)
+   load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
......
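As the hunks above show, the load_tf_weights_in_* helpers now take the model config as a second positional argument. A minimal sketch of the updated BERT call pattern; the paths are placeholders for illustration only:

from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert

# Placeholder paths for illustration only.
config = BertConfig.from_json_file("/path/to/bert_config.json")
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, "/path/to/bert_model.ckpt")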
@@ -36,6 +36,9 @@ if sys.version_info[0] == 2:
else:
    import pickle

+import logging
+logging.basicConfig(level=logging.INFO)

# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer
......
@@ -24,9 +24,10 @@ import torch
import numpy

from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
-from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel)
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES

+import logging
+logging.basicConfig(level=logging.INFO)

def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
    # Load checkpoint
......
@@ -40,6 +40,8 @@ GLUE_TASKS_NUM_LABELS = {
    "wnli": 2,
}

+import logging
+logging.basicConfig(level=logging.INFO)

def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
    # Initialise PyTorch model
@@ -48,14 +50,17 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
    finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""

    if finetuning_task in GLUE_TASKS_NUM_LABELS:
        print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
-       model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
+       config.finetuning_task = finetuning_task
+       config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
+       model = XLNetForSequenceClassification(config)
    elif 'squad' in finetuning_task:
+       config.finetuning_task = finetuning_task
        model = XLNetForQuestionAnswering(config)
    else:
        model = XLNetLMHeadModel(config)

    # Load weights from tf checkpoint
-   load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
+   load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
......
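Head selection for XLNet now reads the task settings from the configuration object instead of passing num_labels to the constructor. A minimal sketch of the new pattern; the config path and task name are placeholders chosen for illustration:

from pytorch_transformers.modeling_xlnet import XLNetConfig, XLNetForSequenceClassification

# Placeholder config path; "wnli" maps to 2 labels in GLUE_TASKS_NUM_LABELS above.
config = XLNetConfig.from_json_file("/path/to/xlnet_config.json")
config.finetuning_task = "wnli"
config.num_labels = 2
model = XLNetForSequenceClassification(config)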
@@ -37,9 +37,11 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra

logger = logging.getLogger(__name__)

XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
+   'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}

XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+   'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
......
@@ -50,7 +50,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-   'transfo-xl-wt103': 512,
+   'transfo-xl-wt103': None,
}

PRETRAINED_CORPUS_ARCHIVE_MAP = {
......
@@ -208,7 +208,8 @@ class PreTrainedTokenizer(object):
            # if we're using a pretrained model, ensure the tokenizer
            # wont index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
-           kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+           if max_len is not None and isinstance(max_len, (int, float)):
+               kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

        # Merge resolved_vocab_files arguments in kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
......
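The guard matters because the positional-embedding size tables in this commit now map some models (e.g. transfo-xl-wt103) to None, and min(..., None) raises a TypeError on Python 3. A standalone sketch of the guarded logic, using a hypothetical helper name:

def resolve_max_len(requested_max_len, model_max_len):
    # Hypothetical helper mirroring the guarded logic above: only clamp
    # the requested length when the model declares a positional limit.
    max_len = requested_max_len if requested_max_len is not None else int(1e12)
    if model_max_len is not None and isinstance(model_max_len, (int, float)):
        max_len = min(max_len, model_max_len)
    return max_len

print(resolve_max_len(None, 512))   # 512 -> clamped to the model limit
print(resolve_max_len(None, None))  # 1000000000000 -> no declared limit (e.g. transfo-xl-wt103)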
@@ -32,12 +32,14 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
+       'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
        'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-   'xlnet-large-cased': 512,
+   'xlnet-base-cased': None,
+   'xlnet-large-cased': None,
}

SPIECE_UNDERLINE = u'▁'
......