# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import argparse
import importlib
import torch.multiprocessing as mp
import sys
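
# Example invocation (illustrative; the paths are placeholders, and the loader
# and saver names assume modules such as loader_megatron.py / saver_megatron.py
# are on the Python path, per load_plugin below):
#
#   python convert.py --model-type GPT --loader megatron --saver megatron \
#       --load-dir /path/to/source/checkpoint --save-dir /path/to/target/checkpoint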

# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments

# A saver is similar but has save_checkpoint instead of
# load_checkpoint
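
# A minimal loader plugin might look like the sketch below (illustrative only;
# the module name and option are hypothetical):
#
#   # loader_example.py
#   def add_arguments(parser):
#       group = parser.add_argument_group(title='example loader')
#       group.add_argument('--example-option', type=str, default=None)
#
#   def load_checkpoint(queue, args):
#       ...  # read the checkpoint, then queue.put() each message described below
#
# A saver plugin has the same shape, with save_checkpoint(queue, args) doing the
# receiving and writing instead of load_checkpoint.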

# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights in messages in the
# order given below; the saver should receive them in the same order
# and save the checkpoint. A message consists of a Python dictionary
# with a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split. (A rough send/receive sketch
# follows the message listing below.)

# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.

# - Metadata Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc.  (Part of protocol to allow this to be deduced later instead of given on command line)
#     num_layers - Number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - Used only if model_type is BERT
#     previous_tensor_parallel_size - Optional
#     previous_pipeline_parallel_size - Optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
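
#   For illustration only (every value below is made up), such a namespace
#   could be built and sent ahead of the weight messages with:
#
#     md = argparse.Namespace(
#         model_type='GPT', num_layers=24, hidden_size=1024, seq_length=2048,
#         num_attention_heads=16, max_position_embeddings=2048,
#         tokenizer_type='GPT2BPETokenizer', iteration=1000,
#         params_dtype=torch.float16, bert_binary_head=False,
#         true_vocab_size=None, make_vocab_size_divisible_by=128,
#         consumed_train_samples=0, consumed_valid_samples=0)
#     queue.put(md)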
# messages
# {
#   "name": "embeddings"
#   "position embeddings"
#   "word embeddings"
# }
# (for each transformer layer):
# {
#   "name": "transformer layer N"
#   "input norm weight"
#   "input norm bias"
#   "qkv weight"
#   "qkv bias"
#   "dense weight"
#   "dense bias"
#   "post norm weight"
#   "post norm bias"
#   "mlp l0 weight"
#   "mlp l0 bias"
#   "mlp l1 weight"
#   "mlp l1 bias"
# }
# {
#   "name": "final layer norm"
#   "weight"
#   "bias"
# }
# if present (i.e. for BERT):
# {
#   "name": "pooler"
#   "weight"
#   "bias"
# }
# {
#   "name": "lm head"
#   "dense weight"
#   "dense bias"
#   "norm weight"
#   "norm bias"
# }
# {
#   "name": "binary head"
#   "weight"
#   "bias"
# }
# - "done"

def load_plugin(plugin_type, name):
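    """Import the requested loader/saver module.

    Tries "<plugin_type>_<name>" first (e.g. "loader_megatron"), then falls
    back to importing "<name>" directly; exits if neither import succeeds or
    the module does not define add_arguments.
    """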
    module_name = f"{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        print(e)
        module_name = name
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            print(e)
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin

def main():
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

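    # The saver runs in a separate "spawn"-context process and receives the
    # weight messages over this queue, while the loader runs in the current
    # process; --max-queue-size bounds the queue so the loader cannot race
    # arbitrarily far ahead of the saver.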
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(maxsize=args.max_queue_size)
    # queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = ctx.Process(target=saver.save_checkpoint, args=(queue, args))
    # saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    loader.load_checkpoint(queue, args)

    print("Waiting for saver to complete...")
    saver_proc.join()


if __name__ == '__main__':
    main()