# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import json
import os
import sys
import torch
import types

from utils import get_mcore_transformer_block_key, print_memory_usage


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='original size of vocab, if specified will trim padding from embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                       'trim padding from the embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')
    group.add_argument('--position-embedding-type',
                       type=str,
                       default='learned_absolute',
                       choices=['learned_absolute', 'rope'],
                       help='Position embedding type.')
    group.add_argument('--loader-transformer-impl', default='transformer_engine',
                       choices=['local', 'transformer_engine'],
                       help='Which Transformer implementation to use.')


def _load_checkpoint(queue, args):

    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.training.arguments import parse_args, validate_args
        from megatron.training.global_vars import set_args, set_global_variables
        from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.legacy.model import module
        from megatron.core import mpu
        from megatron.core.enums import ModelType
        from megatron.legacy import fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--mock-data', # To pass the "blend data checks" in arguments.py
                '--load', args.load_dir,
                '--position-embedding-type', args.position_embedding_type,
                '--exit-on-missing-checkpoint',
                ]

    margs = parse_args()
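    # Overwrite the parsed defaults with the arguments stored in the checkpoint itself.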
    margs, checkpoint_args = load_args_from_checkpoint(margs)

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

    # Explicitly copy data types from checkpoint.
    margs.fp16 = checkpoint_args.fp16
    margs.bf16 = checkpoint_args.bf16

    # Validate margs.
    margs = validate_args(margs)

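    # Force the mcore (non-legacy) model implementation and the requested transformer backend.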
    margs.use_legacy_models = False
    margs.transformer_impl = args.loader_transformer_impl

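    # Verify that an argument was recorded in the checkpoint; fall back to a default
    # if one is given, otherwise abort the conversion.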
    def check_for_arg(arg_name, default=None):
        if getattr(margs, arg_name, None) is None:
            if default is not None:
                setattr(margs, arg_name, default)
            else:
                print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
                print(f"Arguments: {margs}")
                queue.put("exit")
                exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('position_embedding_type')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    check_for_arg('disable_bias_linear', False)
    check_for_arg('params_dtype')
    check_for_arg('swiglu', False)

    # Determine how to make our models
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    # suppress warning about torch.distributed not being initialized
    module.MegatronModule.embedding_warning_printed = True

    consumed_train_samples = None
    consumed_valid_samples = None
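    # Build `count` tensor-parallel copies of the model for the current pipeline stage
    # and load their weights from the checkpoint.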
    def get_models(count, dtype):
        nonlocal consumed_train_samples
        nonlocal consumed_valid_samples
        model_array_len = margs.virtual_pipeline_model_parallel_size
        if model_array_len is None:
            model_array_len = 1
        models = [[] for _ in range(model_array_len)]
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        for rank in range(count):
            mpu.set_tensor_model_parallel_rank(rank)
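            # With interleaved (virtual) pipeline parallelism, each rank holds several model chunks.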
            if margs.virtual_pipeline_model_parallel_size is not None:
                model_ = []
                for i in range(margs.virtual_pipeline_model_parallel_size):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    # Set pre_process and post_process only after virtual rank is set.
                    pre_process = mpu.is_pipeline_first_stage()
                    post_process = mpu.is_pipeline_last_stage()
                    this_model = model_provider(
                        pre_process=pre_process,
                        post_process=post_process
                    ).to(dtype)
                    model_.append(this_model)
            else:
                pre_process = mpu.is_pipeline_first_stage()
                post_process = mpu.is_pipeline_last_stage()
                model_rank = 0
                model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            margs.exit_on_missing_checkpoint = True
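            # Load weights only; optimizer and LR-scheduler state are not needed here.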
            load_checkpoint(model_, None, None)

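            # Every rank's checkpoint must report the same consumed sample counts; remember them for the metadata.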
            if consumed_train_samples is not None:
                assert(margs.consumed_train_samples == consumed_train_samples)
            else:
                consumed_train_samples = margs.consumed_train_samples
            if consumed_valid_samples is not None:
                assert(margs.consumed_valid_samples == consumed_valid_samples)
            else:
                consumed_valid_samples = margs.consumed_valid_samples
            for vp_rank in range(model_array_len):
                models[vp_rank].append(model_[vp_rank])

            # Print memory usage.
            print_memory_usage("loader", rank, count)

        return models

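    # Set up Megatron's global state (no tokenizer needed) and fake the parallel
    # world sizes so the model can be built and loaded in this single process.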
    set_global_variables(margs, build_tokenizer=False)
    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
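    # Load the fused kernels expected by the chosen model implementation.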
    fused_kernels.load(margs)

    # Get true (non-padded) vocab size
    if args.true_vocab_size is not None:
        true_vocab_size = args.true_vocab_size
        if args.vocab_file is not None:
            # Sanity check: the vocab file must agree with the explicit size.
            with open(args.vocab_file) as vocab_file:
                vocab = json.load(vocab_file)
            if len(vocab) != true_vocab_size:
                print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
                queue.put("exit")
                exit(1)
    elif args.vocab_file is not None:
        with open(args.vocab_file) as vocab_file:
            vocab = json.load(vocab_file)
        true_vocab_size = len(vocab)
    else:
        true_vocab_size = None

    # short aliases
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size
    vp_size = margs.virtual_pipeline_model_parallel_size
    if vp_size is None:
        vp_size = 1

    # Layernorm has bias; RMSNorm does not.
    if hasattr(checkpoint_args, 'normalization'):
        norm_has_bias = checkpoint_args.normalization == "LayerNorm"
    else:
        # older models only supported LayerNorm
        norm_has_bias = True

    # metadata
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.output_layer = margs.untie_embeddings_and_output_weights
    md.position_embedding_type = margs.position_embedding_type
    md.linear_bias = margs.add_bias_linear
    md.norm_has_bias = norm_has_bias
    md.swiglu = margs.swiglu
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
    md.checkpoint_args = checkpoint_args
    md.use_legacy_models = margs.use_legacy_models

    # Get transformer block (named either 'encoder' or 'decoder').
    transformer_block_key = get_mcore_transformer_block_key(md.model_type)
    def get_transformer_block(_model):
        return getattr(_model, transformer_block_key)

    # Get first pipe stage
    mpu.set_pipeline_model_parallel_rank(0)
    all_models = [get_models(tp_size, md.params_dtype)]
    models = all_models[0][0]

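    # Send the metadata first so the saver can configure itself.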
    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    queue.put(md)

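    # Helper that tags each message with its name before handing it to the saver.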
    def queue_put(name, msg):
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings
    message = {
        "word embeddings": torch.cat(
            [models[tp_rank].embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
            dim = 0)
    }
    if md.position_embedding_type == 'learned_absolute':
        message["position embeddings"] = models[0].embedding.position_embeddings.weight.data
    else:
        assert not hasattr(models[0].embedding, 'position_embeddings')

    queue_put("embeddings", message)

    total_layer_num = 0
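    # Walk every (virtual) pipeline stage, loading its tensor-parallel shards as
    # needed, and stream each transformer layer's tensors to the saver.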
    for vp_rank in range(vp_size):
        mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
        for pp_rank in range(pp_size):
            if pp_rank > 0:
                mpu.set_pipeline_model_parallel_rank(pp_rank)
                if vp_rank == 0:
                    all_models.append(get_models(tp_size, md.params_dtype))
            models = all_models[pp_rank][vp_rank]
            for layer_num in range(len(get_transformer_block(models[0]).layers)):
                message = {}

                # Get non-parallel tensors from tp_rank 0
                layer = get_transformer_block(models[0]).layers[layer_num]
                message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data
                if norm_has_bias:
                    message["input norm bias"] = layer.self_attention.linear_qkv.layer_norm_bias.data
                message["post norm weight"] = layer.mlp.linear_fc1.layer_norm_weight.data
                if norm_has_bias:
                    message["post norm bias"] = layer.mlp.linear_fc1.layer_norm_bias.data
                if md.linear_bias:
                    message["dense bias"] = layer.self_attention.linear_proj.bias.data
                    message["mlp l1 bias"] = layer.mlp.linear_fc2.bias.data

                # Grab all parallel tensors for this layer
                qkv_weight = []
                qkv_bias = []
                dense_weight = []
                mlp_l0_weight = []
                mlp_l0_bias = []
                mlp_l1_weight = []
                for tp_rank, model in enumerate(models):
                    layer = get_transformer_block(model).layers[layer_num]
                    qkv_weight.append(layer.self_attention.linear_qkv.weight.data)
                    dense_weight.append(layer.self_attention.linear_proj.weight.data)
                    mlp_l0_weight.append(layer.mlp.linear_fc1.weight.data)
                    mlp_l1_weight.append(layer.mlp.linear_fc2.weight.data)
                    if md.linear_bias:
                        qkv_bias.append(layer.self_attention.linear_qkv.bias.data)
                        mlp_l0_bias.append(layer.mlp.linear_fc1.bias.data)

                # Handle gated linear units
                if md.swiglu:
                    # concat all the first halves ('W's) and all the second halves ('V's)
                    for tp_rank in range(tp_size):
                        mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
                    message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
                    message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
                else:
                    message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)

                # simple concat of the rest
                message["qkv weight"] = torch.cat(qkv_weight, dim=0)
                message["dense weight"] = torch.cat(dense_weight, dim=1)
                message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
                if md.linear_bias:
                    message["qkv bias"] = torch.cat(qkv_bias, dim=0)
                    if md.swiglu:
                        for tp_rank in range(tp_size):
                            mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
                        message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0)
                        message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
                    else:
                        message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

                queue_put(f"transformer layer {total_layer_num}", message)

                total_layer_num = total_layer_num + 1

    # Send final norm from tp_rank 0
    message = {
        "weight": get_transformer_block(models[0]).final_layernorm.weight.data,
    }
    if norm_has_bias:
        message["bias"] = get_transformer_block(models[0]).final_layernorm.bias.data
    queue_put("final norm", message)

    if md.output_layer:
        message = {
            "weight": torch.cat(
                [models[tp_rank].output_layer.weight.data for tp_rank in range(tp_size)],
                dim = 0)
        }
        queue_put("output layer", message)


    # Send BERT lm head and binary head if it exists
    if md.model_type == 'BERT':
        message = {
            "weight": models[0].pooler.dense.weight.data,
            "bias": models[0].pooler.dense.bias.data
        }
        queue_put("pooler", message)

        message = {
            "dense weight": models[0].lm_head.dense.weight.data,
            "dense bias": models[0].lm_head.dense.bias.data,
            "norm weight": models[0].lm_head.layer_norm.weight.data,
        }
        if norm_has_bias:
            message["norm bias"] = models[0].lm_head.layer_norm.bias.data
        queue_put("lm head", message)

        if md.bert_binary_head:
            message = {
                "weight": models[0].binary_head.weight.data,
                "bias": models[0].binary_head.bias.data
            }
            queue_put("binary head", message)
    queue.put("done")

def load_checkpoint(queue, args):
    try:
        _load_checkpoint(queue, args)
    except Exception:
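        # Tell the saver process to stop waiting, then re-raise so the error is visible.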
        queue.put("exit")
        raise