__main__.py 2.08 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding: utf8
2
def main():
    """Run one of the checkpoint-conversion commands from the command line.

    Reads ``sys.argv`` and dispatches on its first argument:

    * ``convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT``
    * ``convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]``

    ``sys.argv`` is only read, never modified.

    Returns:
        ``1`` when the command line is malformed (a usage message is printed
        to stdout first), otherwise ``None``. Callers that invoke this as
        ``sys.exit(main())`` therefore get a non-zero exit status on usage
        errors and status 0 on success.

    Raises:
        ModuleNotFoundError: if the TensorFlow conversion module cannot be
            imported; installation advice is printed before re-raising.
    """
    import sys

    argv = sys.argv
    known_commands = ("convert_tf_checkpoint_to_pytorch", "convert_openai_checkpoint")

    # `or` short-circuits, so argv[1] is never read when argv is too short.
    if len(argv) not in (4, 5) or argv[1] not in known_commands:
        # pylint: disable=line-too-long
        print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
        return 1

    if argv[1] == "convert_tf_checkpoint_to_pytorch":
        # Validate the argument count before importing the converter, so a
        # usage error is reported even when that import would fail
        # (e.g. TensorFlow not installed).
        if len(argv) != 5:
            # pylint: disable=line-too-long
            print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
            return 1
        try:
            from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
        except ModuleNotFoundError:
            print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
                "In that case, it requires TensorFlow to be installed. Please see "
                "https://www.tensorflow.org/install/ for installation instructions.")
            raise
        tf_checkpoint, tf_config, pytorch_dump_output = argv[2:5]
        convert_tf_checkpoint_to_pytorch(tf_checkpoint, tf_config, pytorch_dump_output)
        return None

    # Remaining case: "convert_openai_checkpoint" with 4 or 5 arguments.
    from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
    openai_checkpoint_folder_path = argv[2]
    pytorch_dump_output = argv[3]
    # The config file is optional; an empty string is forwarded when omitted.
    openai_config = argv[4] if len(argv) == 5 else ""
    convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path,
                                         openai_config,
                                         pytorch_dump_output)
    return None

if __name__ == '__main__':
    # Propagate main()'s return value as the exit status; None exits with 0.
    import sys
    sys.exit(main())