# coding: utf8
# pylint: disable=line-too-long

# Sub-commands accepted as the first command-line argument.
_COMMANDS = (
    "convert_tf_checkpoint_to_pytorch",
    "convert_openai_checkpoint",
    "convert_transfo_xl_checkpoint",
    "convert_gpt2_checkpoint",
)

# General usage text, shown when the command or the argument count is invalid.
_USAGE = (
    "Should be used as one of: \n"
    ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
    ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
    ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
    ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`"
)

# convert_tf_checkpoint_to_pytorch is the only command whose third argument is mandatory.
_TF_CHECKPOINT_USAGE = "Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`"

# Hint shown before re-raising when a TensorFlow-based converter cannot be imported.
_TENSORFLOW_REQUIRED = (
    "pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
    "In that case, it requires TensorFlow to be installed. Please see "
    "https://www.tensorflow.org/install/ for installation instructions."
)


def _report_error(message):
    """Write *message* followed by a newline to stderr."""
    import sys
    # sys.stderr.write (not print(..., file=...)) keeps this Python 2 compatible.
    sys.stderr.write(message + "\n")


def main():
    """Run one of the checkpoint-conversion commands named in ``sys.argv``.

    Command line::

        pytorch_pretrained_bert COMMAND ARG_A PYTORCH_DUMP_OUTPUT [CONFIG]

    where ``COMMAND`` is one of ``_COMMANDS``. For
    ``convert_tf_checkpoint_to_pytorch`` the arguments are
    ``TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`` and all three are required.
    For the other commands the last argument is an optional config path that
    defaults to ``""``.

    Returns:
        0 after the selected converter has run. 1 when the arguments are
        invalid, in which case a usage message is written to stderr. Passing
        the result to ``sys.exit`` gives a non-zero exit status on misuse.

    Raises:
        ImportError: re-raised when a converter module cannot be imported.
            For the TensorFlow-based converters a TensorFlow installation
            hint is written to stderr first.
    """
    import sys

    argv = sys.argv
    # Short-circuit: argv[1] is only read once the length check has passed.
    if len(argv) not in (4, 5) or argv[1] not in _COMMANDS:
        _report_error(_USAGE)
        return 1

    command = argv[1]
    # Optional trailing config path; converters receive "" when it is absent.
    config = argv[4] if len(argv) == 5 else ""

    if command == "convert_tf_checkpoint_to_pytorch":
        # Validate first so a usage mistake does not require TensorFlow to be installed.
        if len(argv) != 5:
            _report_error(_TF_CHECKPOINT_USAGE)
            return 1
        try:
            from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
        except ImportError:
            _report_error(_TENSORFLOW_REQUIRED)
            raise
        # (TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
        convert_tf_checkpoint_to_pytorch(argv[2], argv[3], argv[4])

    elif command == "convert_openai_checkpoint":
        from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
        # Converter argument order: (checkpoint folder, config, dump output).
        convert_openai_checkpoint_to_pytorch(argv[2], config, argv[3])

    elif command == "convert_transfo_xl_checkpoint":
        try:
            from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
        except ImportError:
            _report_error(_TENSORFLOW_REQUIRED)
            raise
        # A path containing "ckpt" is treated as a TF checkpoint; anything else as a dataset file.
        if "ckpt" in argv[2].lower():
            tf_checkpoint, tf_dataset_file = argv[2], ""
        else:
            tf_checkpoint, tf_dataset_file = "", argv[2]
        convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint, config, argv[3], tf_dataset_file)

    else:  # command == "convert_gpt2_checkpoint"
        try:
            from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
        except ImportError:
            _report_error(_TENSORFLOW_REQUIRED)
            raise
        # (TF_CHECKPOINT, GPT2_CONFIG, PYTORCH_DUMP_OUTPUT)
        convert_gpt2_checkpoint_to_pytorch(argv[2], config, argv[3])

    return 0
if __name__ == '__main__':
    # Propagate main()'s status so `python -m pytorch_pretrained_bert` exits
    # non-zero on misuse; SystemExit(None) still means exit status 0.
    raise SystemExit(main())