import json
import os
import sys
import types

import torch
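
# This loader is designed to be driven by a checkpoint converter: run in a
# worker process, it rebuilds the Megatron model from a saved checkpoint (one
# replica per tensor-parallel rank, one pipeline stage at a time) and streams
# the weights back over a multiprocessing queue.  Protocol: a metadata
# namespace first, then one named dict per tensor group, then the string
# "done" (or "exit" if anything goes wrong).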

def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='Original size of the vocab; if specified, padding is trimmed from the embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified, it is used to determine the vocab size and '
                       'trim padding from the embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

def _load_checkpoint(queue, args):
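    """Rebuild the model saved under args.load_dir and push its weights, piece
    by piece, onto `queue` for the consumer (typically a checkpoint saver)."""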

    # Search the directory above this one so the Megatron package can be imported
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables
        from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.model import module
        from megatron.core import mpu
        from megatron.core.enums import ModelType
        from megatron import fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--load', args.load_dir
                ]

    # parse_args() consumes the faked sys.argv built above; the arguments saved
    # with the checkpoint then override those command-line defaults.
    margs = parse_args()
    margs = load_args_from_checkpoint(margs)

    # Argument validation does sanity checks on the world size, but we don't
    # care here, so trick it into thinking there are plenty of processes.
    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

    margs = validate_args(margs)

    def check_for_arg(arg_name):
        if getattr(margs, arg_name, None) is None:
            print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
            print(f"Arguments: {margs}")
            queue.put("exit")
            exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    check_for_arg('params_dtype')

    # Determine how to make our models
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    # Suppress the warning about torch.distributed not being initialized
    module.MegatronModule.embedding_warning_printed = True

    consumed_train_samples = None
    consumed_valid_samples = None
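
    # Build and load one model replica per tensor-parallel rank for the current
    # pipeline stage, checking that every rank reports the same consumed
    # train/valid sample counts.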
    def get_models(count, dtype, pre_process, post_process):
        nonlocal consumed_train_samples
        nonlocal consumed_valid_samples
        models = []
        for rank in range(count):
            mpu.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert(len(model_) == 1)
            model_ = model_[0]
            if consumed_train_samples is not None:
                assert(margs.consumed_train_samples == consumed_train_samples)
            else:
                consumed_train_samples = margs.consumed_train_samples
            if consumed_valid_samples is not None:
                assert(margs.consumed_valid_samples == consumed_valid_samples)
            else:
                consumed_valid_samples = margs.consumed_valid_samples
            models.append(model_)
        return models

    if margs.num_layers_per_virtual_pipeline_stage is not None:
        print("Models with an interleaved pipeline schedule are not yet supported.")
        queue.put("exit")
        exit(1)

    # Megatron keeps its configuration in process-wide globals; install the
    # checkpoint's arguments and fake the parallel world sizes before building
    # any model replica.
    set_global_variables(margs)
    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    fused_kernels.load(margs)

    # Get the true (non-padded) vocab size; if both --true-vocab-size and
    # --vocab-file are given, make sure they agree.
    true_vocab_size = args.true_vocab_size
    if args.vocab_file is not None:
        with open(args.vocab_file) as vocab_file:
            vocab = json.load(vocab_file)
        if true_vocab_size is not None and len(vocab) != true_vocab_size:
            print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
            queue.put("exit")
            exit(1)
        true_vocab_size = len(vocab)

    # short aliases
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size

    # metadata
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

    # Get first pipe stage
    mpu.set_pipeline_model_parallel_rank(0)
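    # pre_process=True so this stage owns the embeddings; post_process is only
    # True when there is a single pipeline stage (that stage then also owns the
    # output layers).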
    post_process = pp_size == 1
    models = get_models(tp_size, md.params_dtype, True, post_process)

    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    queue.put(md)

    def queue_put(name, msg):
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings
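    # Word embeddings are split along the vocab dimension across tensor-parallel
    # ranks, so concatenate them on dim 0; position embeddings are replicated,
    # so take rank 0's copy.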
    message = {
        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
        "word embeddings": torch.cat(
            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
            dim = 0)
    }

    queue_put("embeddings", message)

    total_layer_num = 0
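    # Walk the pipeline stages in order, reloading the per-rank models for each
    # stage, and stream every transformer layer.  Column-parallel weights (QKV,
    # MLP h->4h) are concatenated along dim 0, row-parallel weights (attention
    # dense, MLP 4h->h) along dim 1.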
    for pp_rank in range(pp_size):
        if pp_rank > 0:
            mpu.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == pp_size - 1
            models = get_models(tp_size, md.params_dtype, False, post_process)
        for layer_num in range(len(models[0].language_model.encoder.layers)):
            message = {}

            # Get non-parallel tensors from tp_rank 0
            layer = models[0].language_model.encoder.layers[layer_num]
            message["input layernorm weight"] = layer.input_layernorm.weight.data
            message["input layernorm bias"] = layer.input_layernorm.bias.data
            message["dense bias"] = layer.self_attention.dense.bias.data
            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data

            # Grab all parallel tensors for this layer
            qkv_weight = []
            qkv_bias = []
            dense_weight = []
            mlp_l0_weight = []
            mlp_l0_bias = []
            mlp_l1_weight = []
            for tp_rank, model in enumerate(models):
                layer = model.language_model.encoder.layers[layer_num]
                qkv_weight.append(layer.self_attention.query_key_value.weight.data)
                qkv_bias.append(layer.self_attention.query_key_value.bias.data)
                dense_weight.append(layer.self_attention.dense.weight.data)
                mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)

            # Concatenate them across tensor-parallel ranks
            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
            message["dense weight"] = torch.cat(dense_weight, dim=1)
            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)

            queue_put(f"transformer layer {total_layer_num}", message)

            total_layer_num = total_layer_num + 1

    # Send final layernorm from tp_rank 0
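    # After the pipeline loop, `models` holds the last stage, which owns the
    # final layernorm and, for BERT, the heads below.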
    message = {
        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
        "bias": models[0].language_model.encoder.final_layernorm.bias.data
    }
    queue_put("final layernorm", message)

    # Send BERT lm head and binary head if it exists
    if md.model_type == 'BERT':
        message = {
            "weight": models[0].language_model.pooler.dense.weight.data,
            "bias": models[0].language_model.pooler.dense.bias.data
        }
        queue_put("pooler", message)

        message = {
            "dense weight": models[0].lm_head.dense.weight.data,
            "dense bias": models[0].lm_head.dense.bias.data,
            "layernorm weight": models[0].lm_head.layernorm.weight.data,
            "layernorm bias": models[0].lm_head.layernorm.bias.data
        }
        queue_put("lm head", message)

        if md.bert_binary_head:
            message = {
                "weight": models[0].binary_head.weight.data,
                "bias": models[0].binary_head.bias.data
            }
            queue_put("binary head", message)
    queue.put("done")

def load_checkpoint(queue, args):
    try:
        _load_checkpoint(queue, args)
    except BaseException:
        # Forward the failure to the consumer so it unblocks, then re-raise so
        # the traceback still shows up in this process.
        queue.put("exit")
        raise
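

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module).  In practice this
# loader is driven by a checkpoint converter script, which runs it in a
# separate process and feeds the queue to a matching saver.  The standalone
# driver below only illustrates that protocol; the '--model-type' and
# '--load-dir' flags are defined here for the demo (the real converter
# normally provides them), and the rest is a sketch, not the official entry
# point.
if __name__ == "__main__":
    import argparse
    import multiprocessing as mp

    parser = argparse.ArgumentParser(description="Demo driver for the Megatron loader")
    parser.add_argument('--model-type', choices=['GPT', 'BERT'], required=True)
    parser.add_argument('--load-dir', required=True,
                        help='Directory containing the Megatron checkpoint')
    add_arguments(parser)
    demo_args = parser.parse_args()

    # The loader streams: a metadata namespace, then one dict per named tensor
    # group, then "done" (or "exit" on failure).
    queue = mp.Queue(maxsize=50)
    proc = mp.Process(target=load_checkpoint, args=(queue, demo_args))
    proc.start()
    while True:
        msg = queue.get()
        if isinstance(msg, str):      # "done" or "exit" sentinel
            print("loader finished with:", msg)
            break
        elif isinstance(msg, dict):   # per-layer / per-head weight messages
            print("received", msg["name"], "with", len(msg) - 1, "tensors")
        else:                         # the initial metadata namespace
            print("metadata:", msg)
    proc.join()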