# coding: utf8
def _report_missing_tensorflow():
    """Print a hint that converting TensorFlow checkpoints requires TensorFlow."""
    print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch. "
          "In that case, it requires TensorFlow to be installed. Please see "
          "https://www.tensorflow.org/install/ for installation instructions.")


def _convert_bert(argv):
    """Handle ``bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT``."""
    # Validate arity before importing so a usage error never turns into an ImportError.
    if len(argv) != 5:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
        return
    try:
        from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
    except ImportError:
        # Any ImportError is treated as "TensorFlow missing"; re-raise so the real cause still shows.
        _report_missing_tensorflow()
        raise
    # Index rather than pop(), so the process-wide sys.argv is not mutated.
    tf_checkpoint, tf_config, pytorch_dump_output = argv[2:5]
    convert_tf_checkpoint_to_pytorch(tf_checkpoint, tf_config, pytorch_dump_output)


def _convert_gpt(argv):
    """Handle ``gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]``."""
    if not 4 <= len(argv) <= 5:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
        return
    # This converter's import is not wrapped with the TensorFlow hint.
    from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
    openai_gpt_config = argv[4] if len(argv) == 5 else ""
    convert_openai_checkpoint_to_pytorch(argv[2], openai_gpt_config, argv[3])


def _convert_transfo_xl(argv):
    """Handle ``transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]``."""
    if not 4 <= len(argv) <= 5:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
        return
    try:
        from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
    except ImportError:
        _report_missing_tensorflow()
        raise
    # A path containing "ckpt" (case-insensitive) is treated as a TF checkpoint,
    # anything else as a dataset file; the unused slot is passed as "".
    if 'ckpt' in argv[2].lower():
        tf_checkpoint, tf_dataset_file = argv[2], ""
    else:
        tf_checkpoint, tf_dataset_file = "", argv[2]
    tf_config = argv[4] if len(argv) == 5 else ""
    convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint, tf_config, argv[3], tf_dataset_file)


def _convert_gpt2(argv):
    """Handle ``gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]``."""
    if not 4 <= len(argv) <= 5:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
        return
    try:
        from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
    except ImportError:
        _report_missing_tensorflow()
        raise
    tf_config = argv[4] if len(argv) == 5 else ""
    convert_gpt2_checkpoint_to_pytorch(argv[2], tf_config, argv[3])


def _convert_xlnet(argv):
    """Handle ``xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]``."""
    if not 5 <= len(argv) <= 6:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
        return
    try:
        from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
    except ImportError:
        _report_missing_tensorflow()
        raise
    # The finetuning task is optional; None is passed when it is omitted.
    finetuning_task = argv[5] if len(argv) == 6 else None
    convert_xlnet_checkpoint_to_pytorch(argv[2], argv[3], argv[4], finetuning_task)


def _convert_xlm(argv):
    """Handle ``xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT``."""
    if len(argv) != 4:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
        return
    from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
    convert_xlm_checkpoint_to_pytorch(argv[2], argv[3])


# Maps the MODEL_TYPE command-line argument (sys.argv[1]) to its handler.
_CONVERTERS = {
    "bert": _convert_bert,
    "gpt": _convert_gpt,
    "transfo_xl": _convert_transfo_xl,
    "gpt2": _convert_gpt2,
    "xlnet": _convert_xlnet,
    "xlm": _convert_xlm,
}


def main():
    """Command-line entry point: convert a model checkpoint to a PyTorch dump.

    Reads ``sys.argv`` as ``pytorch_transformers MODEL_TYPE ARGS...``, where
    MODEL_TYPE is one of bert, gpt, transfo_xl, gpt2, xlnet, xlm. If the
    arguments are malformed, a usage message is printed to stdout and nothing
    is converted. ``sys.argv`` itself is never modified.

    Raises:
        ImportError: if the selected converter module cannot be imported. For
            the TensorFlow-based converters, a TensorFlow install hint is
            printed first.
    """
    import sys

    argv = sys.argv
    # The length test short-circuits, so argv[1] is only read when it exists.
    if not 4 <= len(argv) <= 6 or argv[1] not in _CONVERTERS:
        print(
        "Should be used as one of: \n"
        ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
        ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
        ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
        ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
        ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
        ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
        return
    _CONVERTERS[argv[1]](argv)


if __name__ == "__main__":
    # Run the conversion CLI only when executed directly, never on import.
    main()