# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""
import logging
import random
import os
import time
import warnings

import numpy as np
import torch
from datetime import timedelta

from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume
from megatron.training import get_args
from megatron.training import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.core.rerun_state_machine import initialize_rerun_state_machine, RerunErrorInjector, RerunDiagnostic, RerunMode
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version

logger = logging.getLogger(__name__)


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (only when args.lazy_mpu_init == True; otherwise returns None).
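
    Illustrative usage (a sketch only; a real pretraining script supplies its
    own extra args provider and defaults, the values below are just examples):

        initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
        args = get_args()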
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    # Prep for checkpoint conversion.
    if args.ckpt_convert_format is not None:
        assert args.ckpt_convert_save is not None, "ckpt_convert_save must be set for checkpoint conversion"
        assert args.load is not None, "a checkpoint to load must be provided for checkpoint conversion"
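        # Conversion needs an existing checkpoint, so fail fast if it is missing.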
        args.exit_on_missing_checkpoint = True

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoint-args requires --load argument"
        load_args_from_checkpoint(args)

    if args.yaml_cfg is not None:
        args = validate_yaml(args, args_defaults)
    else:
        validate_args(args, args_defaults)


    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # set logging level
    setup_logging()

    # init rerun state
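    # The save/restore callbacks below snapshot the tensor-parallel CUDA RNG
    # tracker state so the rerun state machine can replay an iteration with the
    # same random state (e.g. when checking whether a failure is reproducible).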
    def state_save_func():
        return {
            'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()
        }

    def state_restore_func(state_dict):
        if state_dict['rng_tracker_states']:
            tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states'])

    args = get_args()
    initialize_rerun_state_machine(
        state_save_func=state_save_func,
        state_restore_func=state_restore_func,
        mode=RerunMode(args.rerun_mode),
        error_injector=RerunErrorInjector(
            error_injection_rate=args.error_injection_rate,
            error_injection_type=RerunDiagnostic(args.error_injection_type),
        ),
    )

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # PyTorch distributed.
        _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks)

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        if args.tp_comm_overlap:
            # TODO: Should this be activated with just decoder-tp-comm-overlap too?
            _initialize_tp_communicators()

        # No continuation function
        return None


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.core.datasets.utils import compile_helpers

        compile_helpers()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp-based
    # optimization and upper-triangular optimization (for the causal mask).
    custom_kernel_constraint = (
        seq_len > 16
        and seq_len <= 16384
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
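    # Example: seq_len=4096 with 32 attention heads, TP size 8, and
    # micro_batch_size=2 gives attn_batch_size=8, so the constraint holds.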
    # Print a warning.
    if not (
        (args.fp16 or args.bf16)
        and custom_kernel_constraint
        and args.masked_softmax_fusion
    ):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        #fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        #fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the compilation phase
    # successfully before moving on to the rest of the program. This should
    # also ensure that the build lock has been released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )

def _initialize_tp_communicators():
    """Initialize communicators with userbuffers for high-performance
    tensor-model-parallel communication overlap."""

    try:
        import yaml

        import transformer_engine
        from transformer_engine.pytorch import module as te_module

    except ImportError:
        raise RuntimeError(
            "Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
            "'transformer_engine' packages"
        )

    args = get_args()

    if args.tp_comm_overlap_cfg is not None:
        with open(args.tp_comm_overlap_cfg, "r") as stream:
            ub_cfgs = yaml.safe_load(stream)
    else:
        ub_cfgs = {}

    if getattr(args, 'decoder_tp_comm_overlap', False):
        input_shape = [(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size, args.hidden_size]
    else:
        input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size, args.hidden_size]
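    # input_shape: tokens handled per context-parallel rank (sequence length *
    # micro batch size // CP size) by the hidden size; used to size the
    # userbuffers passed to Transformer Engine below.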

    if is_te_min_version("1.9.0"):
        # The process group with the target bootstrap backend is created in Transformer Engine.
        te_module.base.initialize_ub(shape=input_shape, tp_size=args.tensor_model_parallel_size,
                                     use_fp8=(args.fp8 is not None), ub_cfgs=ub_cfgs,
                                     bootstrap_backend=args.tp_comm_bootstrap_backend)
    else:
        if args.tp_comm_bootstrap_backend != 'mpi':
            warnings.warn(
                f"Transformer Engine v{get_te_version()} supports only the MPI bootstrap backend."
            )
        # Create an MPI process group to help with TP communication overlap bootstrap.
        torch.distributed.new_group(backend='mpi')

        te_module.base.initialize_ub(shape=input_shape, tp_size=args.tensor_model_parallel_size,
                                     use_fp8=(args.fp8 is not None), ub_cfgs=ub_cfgs)

def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:
        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            torch.cuda.set_device(args.local_rank)
            device_id = torch.device(f'cuda:{args.local_rank}')
        else:
            device_id = None

        # Initialize the torch.distributed process group.
        init_process_group_kwargs = {
            'backend' : args.distributed_backend,
            'world_size': args.world_size,
            'rank': args.rank,
            'init_method': args.dist_url,
            'timeout': timedelta(minutes=args.distributed_timeout_minutes),
        }

        torch.distributed.init_process_group(**init_process_group_kwargs)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
                num_distributed_optimizer_instances=args.num_distributed_optimizer_instances,
                expert_tensor_parallel_size=args.expert_tensor_parallel_size,
                distributed_timeout_minutes=args.distributed_timeout_minutes,
                nccl_communicator_config_path=args.nccl_communicator_config_path,
                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
                encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
                encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
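        # Example: seed_=1234 on pipeline rank 2 with data-parallel rank 3 gives
        # 1234 + 100 * 2 + 10 * 3 = 1464 (the data-parallel offset applies only
        # when data_parallel_random_init is set).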
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    if is_torch_min_version("2.2.0a0"):
        pass  # we're using torch.compile for jit fusion
    elif is_torch_min_version("1.10.0a0"):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compilie JIT functions before the main training steps"""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(
        args.ffn_hidden_size // args.tensor_model_parallel_size,
        dtype=dtype,
        device="cuda",
    )
    input = torch.rand(
        (
            args.seq_length // args.context_parallel_size,
            args.micro_batch_size,
            args.ffn_hidden_size // args.tensor_model_parallel_size,
        ),
        dtype=dtype,
        device="cuda",
    )
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            if args.swiglu:
                output = bias_swiglu(input, bias)
            else:
                output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand(
        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    residual = torch.rand(
        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
        residual
    )
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip(
        [False, True], [True, True], [True, True]
    ):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()


def setup_logging() -> None:
    """ Sets the default logging level based on cmdline args and env vars.

    Precedence:
    1. Command line argument `--logging-level`
    2. Env var `MEGATRON_LOGGING_LEVEL`
    3. Default logging level (INFO)
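
    Example (illustrative): running with `MEGATRON_LOGGING_LEVEL=10` or
    `--logging-level 10` selects DEBUG output; values follow the numeric
    levels of the stdlib `logging` module.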

    Returns: None
    """
    args = get_args()
    logging_level = None
    env_logging_level = os.getenv('MEGATRON_LOGGING_LEVEL', None)
    if env_logging_level is not None:
        logging_level = int(env_logging_level)
    if args.logging_level is not None:
        logging_level = args.logging_level

    if logging_level is not None:
        logger.info(f'Setting logging level to {logging_level}')
        logging.getLogger().setLevel(logging_level)