# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import os
import sys
import torch


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron saver')

    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--saver-transformer-impl', default='local',
                       choices=['local', 'transformer_engine'],
                       help='Which Transformer implementation to use.')

def save_checkpoint(queue, args):
    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir,
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.training.arguments import (parse_args, validate_args)
        from megatron.training.checkpointing import save_checkpoint
        from megatron.training.global_vars import set_global_variables, get_args
        from megatron.core.enums import ModelType
        from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
        from megatron.legacy import fused_kernels
        from megatron.core import mpu
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        exit(1)

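    # Helper: pull the next message from the loader process over the shared queue;
    # optionally verify the message name unless --no-checking was given.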
    def queue_get(name=None):
        val = queue.get()
        if val == "exit":
            print("Loader exited, exiting saver")
            exit(1)
        if name is not None and args.checking and val["name"] != name:
            val_name = val["name"]
            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            exit(1)
        if name is not None:
            print(f"received {name}")
        return val

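    # Helper: after all expected tensors have been popped from a message, make sure
    # the loader did not send anything we do not know how to handle.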
    def check_message(msg):
        if not args.checking:
            return
        msg_name = msg.pop("name")
        if len(msg.keys()) > 0:
            print(f"Unexpected values in {msg_name}:")
            for key in msg.keys():
                print(f"   {key}")
            print("Exiting. If you want to ignore this, use the argument --no-checking.")
            exit(1)

    md = queue_get()

    if args.target_tensor_parallel_size is None:
        if hasattr(md, 'previous_tensor_parallel_size'):
            args.target_tensor_parallel_size = md.previous_tensor_parallel_size
        else:
            print(
                "Loader did not provide a tensor parallel size and --target-tensor-parallel-size "
                "was not provided on the command line. Defaulting to 1.")
            args.target_tensor_parallel_size = 1

    if args.target_pipeline_parallel_size is None:
        if hasattr(md, 'previous_pipeline_parallel_size'):
            args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
        else:
            print(
                "Loader did not provide a pipeline parallel size and --target-pipeline-parallel-size "
                "was not provided on the command line. Defaulting to 1.")
            args.target_pipeline_parallel_size = 1

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we have plenty of processes.
    if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
        os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--num-layers', str(md.num_layers),
                '--hidden-size', str(md.hidden_size),
                '--seq-length', str(md.seq_length),
                '--num-attention-heads', str(md.num_attention_heads),
                '--max-position-embeddings', str(md.max_position_embeddings),
                '--position-embedding-type', str(md.position_embedding_type),
                '--tokenizer-type', str(md.tokenizer_type),
                '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--save-interval', '1',
                '--save', args.save_dir,
                '--ckpt-format', 'torch', # only 'torch' supported for conversion
                '--no-one-logger',
                ]

    if md.make_vocab_size_divisible_by is not None:
        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
    if md.params_dtype == torch.float16:
        sys.argv.append('--fp16')
    elif md.params_dtype == torch.bfloat16:
        sys.argv.append('--bf16')

    if md.output_layer:
        sys.argv.append('--untie-embeddings-and-output-weights')
    if not md.linear_bias:
        sys.argv.append('--disable-bias-linear')

    if md.model_type == 'BERT' and not md.bert_binary_head:
        sys.argv.append('--bert-no-binary-head')

    margs = parse_args()

    if hasattr(md, 'checkpoint_args'):
        # These are arguments that we are either changing, or cause problems for validation if they are set
        # Note that some of these deal with T5 so will need to be changed if we support T5.
        args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype',
                        'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size',
                        'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion',
                        'sequence_parallel', 'async_tensor_model_parallel_allreduce',
                        'no_load_optim', 'no_load_rng', 'no_save_optim', 'no_save_rng',
                        'vocab_file', 'tokenizer_model',
                        'save_interval', 'save',
                        'perform_initialization', 'use_cpu_initialization',
                        'recompute_granularity', 'recompute_num_layers', 'recompute_method',
                        'encoder_num_layers', 'encoder_seq_length',
                        'distribute_saved_activations',
                        'train_iters', 'lr_decay_iters', 'lr_warmup_iters', 'lr_warmup_fraction',
                        'start_weight_decay', 'end_weight_decay', 'bf16', 'fp16',
                        'ckpt_format',
        ]

        for arg, value in vars(md.checkpoint_args).items():
            if arg in args_to_keep:
                continue
            if not hasattr(margs, arg):
                print(f"Checkpoint had argument {arg} but the new arguments do not have it.")
                continue
            if getattr(margs, arg) != value:
                print(f"Overwriting default {arg} value {getattr(margs, arg)} with value from checkpoint {value}.")
                setattr(margs, arg, value)

    validate_args(margs)

    # Use legacy Megatron-LM models.
    margs.use_legacy_models = True
    margs.transformer_impl = args.saver_transformer_impl

    # Do not instantiate Tensorboard
    margs.tensorboard_dir = None

    set_global_variables(margs, build_tokenizer=False)

    # margs = megatron args
    margs = get_args()

    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
    else:
        print("consumed_train_samples not provided.")

    # Determine how to make our models
    if md.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif md.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {md.model_type}')

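    # Build `count` copies of the model (one per tensor-parallel rank) on CPU and
    # cast them to the checkpoint dtype; pre_process/post_process select whether the
    # embedding and the output head are instantiated for this pipeline stage.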
    def get_models(count, dtype, pre_process, post_process):
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

    # fake initializing distributed
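    # Setting the mpu world sizes and ranks here lets each model be constructed with
    # correctly shaped per-rank shards without initializing torch.distributed.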
    mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
    mpu.set_tensor_model_parallel_rank(0)
    mpu.set_pipeline_model_parallel_rank(0)
    fused_kernels.load(margs)

    # Embeddings
    # -----------
    embeddings_msg = queue_get("embeddings")

    pos_embed = None
    if md.position_embedding_type == 'learned_absolute':
        pos_embed = embeddings_msg.pop("position embeddings")
    orig_word_embed = embeddings_msg.pop("word embeddings")
    check_message(embeddings_msg)

    # Deal with padding
    if md.true_vocab_size is not None:
        # figure out what our padded vocab size is
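        # Megatron pads the vocab to a multiple of make-vocab-size-divisible-by scaled
        # by the tensor-parallel size, so the padded size can differ from the source
        # checkpoint's when the target tensor-parallel size changes.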
        orig_vocab_size = orig_word_embed.shape[0]
        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)

        # Cut out extra padding we don't need
        if orig_vocab_size > margs.padded_vocab_size:
            full_word_embed = orig_word_embed[0:margs.padded_vocab_size, :]

        # Expand the embedding to the larger size by replicating the final entry
        elif orig_vocab_size < margs.padded_vocab_size:
            padding_size = margs.padded_vocab_size - orig_vocab_size

            full_word_embed = torch.cat((
                orig_word_embed,
                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))

        # Same size!
        else:
            full_word_embed = orig_word_embed
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
              "If you've changed the tensor parallel size this could cause problems.")
        margs.padded_vocab_size = orig_word_embed.shape[0]
        full_word_embed = orig_word_embed

    # Split into new tensor model parallel sizes
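    # (dim 0 is the padded vocab dimension; each chunk becomes one rank's shard)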
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

    # Make models for first pipeline stage and fill in embeddings
    mpu.set_pipeline_model_parallel_rank(0)
    post_process = args.target_pipeline_parallel_size == 1
    models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
        model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
        if pos_embed is not None:
            model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
        else:
            assert not hasattr(model.language_model.embedding, "position_embeddings")

    # Transformer layers
    # -------------------
    total_layer_num = 0
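    # Layers stream from the loader in global order; total_layer_num tracks the
    # global layer index across all target pipeline stages.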
    for pp_rank in range(args.target_pipeline_parallel_size):
        # For later pipeline parallel ranks, make the new models
        if pp_rank > 0:
            mpu.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == args.target_pipeline_parallel_size - 1
            models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)

        for layer in range(len(models[0].language_model.encoder.layers)):
            msg = queue_get(f"transformer layer {total_layer_num}")

            # Duplicated tensors, replicated on every tensor-parallel rank
            input_norm_weight = msg.pop("input norm weight")
            if md.norm_has_bias:
                input_norm_bias = msg.pop("input norm bias")
            post_norm_weight = msg.pop("post norm weight")
            if md.norm_has_bias:
                post_norm_bias = msg.pop("post norm bias")
            if md.linear_bias:
                dense_bias = msg.pop("dense bias")
                mlp_l1_bias = msg.pop("mlp l1 bias")

            # Split up the parallel tensors
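            # Column-parallel weights (QKV, MLP h->4h) are chunked along dim 0,
            # row-parallel weights (attention dense, MLP 4h->h) along dim 1.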
            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

            # Special handling for swiglu
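            # The two halves of the fused h->4h weight (W and V) arrive separately;
            # chunk each per rank and re-concatenate so every rank keeps its own [W; V].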
            if md.swiglu:
                mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0)
                mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0)
                mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
            else:
                mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)

            if md.qkv_bias:
                qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
            if md.linear_bias:
                if md.swiglu:
                    mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0)
                    mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0)
                    mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
                else:
                    mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)

            # Save them to the model
            for tp_rank in range(args.target_tensor_parallel_size):
                l = models[tp_rank].language_model.encoder.layers[layer]
                l.input_norm.weight.data.copy_(input_norm_weight)
                if md.norm_has_bias:
                    l.input_norm.bias.data.copy_(input_norm_bias)
                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                l.post_attention_norm.weight.data.copy_(post_norm_weight)
                if md.norm_has_bias:
                    l.post_attention_norm.bias.data.copy_(post_norm_bias)
                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                if md.qkv_bias:
                    l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
                if md.linear_bias:
                    l.self_attention.dense.bias.data.copy_(dense_bias)
                    l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
                    l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)

            total_layer_num = total_layer_num + 1
            check_message(msg)

        if post_process:
            msg = queue_get("final norm")
            final_norm_weight = msg.pop("weight")
            if md.norm_has_bias:
                final_norm_bias = msg.pop("bias")
            for tp_rank in range(args.target_tensor_parallel_size):
                models[tp_rank].language_model.encoder.final_norm.weight.data.copy_(final_norm_weight)
                if md.norm_has_bias:
                    models[tp_rank].language_model.encoder.final_norm.bias.data.copy_(final_norm_bias)
                if pp_rank != 0 and not md.output_layer:
                    # Copy word embeddings to final pipeline rank
                    models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
            del final_norm_weight
            if md.norm_has_bias:
                del final_norm_bias
            check_message(msg)

            if md.output_layer:
                msg = queue_get("output layer")
                if not hasattr(models[0].language_model, 'output_layer'):
                    print("ERROR: got an output layer, but model does not have one")
                    exit(1)
                output_layer_weight = torch.chunk(msg.pop("weight"), args.target_tensor_parallel_size, dim=0)
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].language_model.output_layer.weight.data.copy_(output_layer_weight[tp_rank])
                del output_layer_weight
                check_message(msg)

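            # GPT checkpoints send "done" at this point; BERT checkpoints may first
            # send pooler, lm head, and binary head messages, handled below.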
            msg = queue_get()
            if msg != "done" and msg["name"] == "pooler":
                if not hasattr(models[0].language_model, 'pooler'):
                    print("ERROR: got a pooler, but model does not have one")
                    exit(1)
                print("received pooler")
                pooler_weight = msg.pop("weight")
                pooler_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                    models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
                del pooler_weight
                del pooler_bias
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "lm head":
                if not hasattr(models[0], 'lm_head'):
                    print("ERROR: got an lm head, but model does not have one")
                    exit(1)
                print("received lm head")
                lm_head_dense_weight = msg.pop("dense weight")
                lm_head_dense_bias = msg.pop("dense bias")
                lm_head_norm_weight = msg.pop("norm weight")
                if md.norm_has_bias:
                    lm_head_norm_bias = msg.pop("norm bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                    models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                    models[tp_rank].lm_head.norm.weight.data.copy_(lm_head_norm_weight)
                    if md.norm_has_bias:
                        models[tp_rank].lm_head.norm.bias.data.copy_(lm_head_norm_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "binary head":
                if not hasattr(models[0], 'binary_head'):
                    print("ERROR: got a binary head, but model does not have one")
                    exit(1)
                print("received binary head")
                binary_head_weight = msg.pop("weight")
                binary_head_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                    models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done":
                print("ERROR: got some more data but was expecting to be done")

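        # Save this pipeline stage: one checkpoint shard per tensor-parallel rank,
        # with the mpu rank set so Megatron writes each shard to its own mp_rank_* directory.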
        for tp_rank in range(args.target_tensor_parallel_size):
            mpu.set_tensor_model_parallel_rank(tp_rank)
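            # Only model weights are written; no optimizer or LR-scheduler state is
            # passed (the two None arguments).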
            save_checkpoint(md.iteration, [models[tp_rank]], None, None,
                            num_floating_point_operations_so_far=0)
    print("Done!")