Commit 33dd59e9 authored by thomwolf's avatar thomwolf
Browse files

update conversion script names

parent 5951d860
...@@ -3,7 +3,8 @@ def main(): ...@@ -3,7 +3,8 @@ def main():
import sys import sys
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
print( print(
"Should be used as one of: \n" "This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
"It should be used as one of: \n"
">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
...@@ -13,7 +14,7 @@ def main(): ...@@ -13,7 +14,7 @@ def main():
else: else:
if sys.argv[1] == "bert": if sys.argv[1] == "bert":
try: try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
...@@ -29,7 +30,7 @@ def main(): ...@@ -29,7 +30,7 @@ def main():
TF_CHECKPOINT = sys.argv.pop() TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "gpt": elif sys.argv[1] == "gpt":
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch from .convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
if len(sys.argv) < 4 or len(sys.argv) > 5: if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
...@@ -45,7 +46,7 @@ def main(): ...@@ -45,7 +46,7 @@ def main():
PYTORCH_DUMP_OUTPUT) PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "transfo_xl": elif sys.argv[1] == "transfo_xl":
try: try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch from .convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
...@@ -69,7 +70,7 @@ def main(): ...@@ -69,7 +70,7 @@ def main():
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
elif sys.argv[1] == "gpt2": elif sys.argv[1] == "gpt2":
try: try:
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch from .convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
...@@ -89,7 +90,7 @@ def main(): ...@@ -89,7 +90,7 @@ def main():
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "xlnet": elif sys.argv[1] == "xlnet":
try: try:
from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch from .convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
...@@ -113,7 +114,7 @@ def main(): ...@@ -113,7 +114,7 @@ def main():
PYTORCH_DUMP_OUTPUT, PYTORCH_DUMP_OUTPUT,
FINETUNING_TASK) FINETUNING_TASK)
elif sys.argv[1] == "xlm": elif sys.argv[1] == "xlm":
from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch from .convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
if len(sys.argv) != 4: if len(sys.argv) != 4:
# pylint: disable=line-too-long # pylint: disable=line-too-long
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment