# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.

    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
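
    Minimal usage sketch (the defaults shown here are illustrative, not required):

        finalize_fn = initialize_megatron(
            args_defaults={"tokenizer_type": "GPT2BPETokenizer"})
        if finalize_fn is not None:  # returned only when args.lazy_mpu_init is set
            finalize_fn()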
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoint-args requires --load argument"
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.data.dataset_utils import compile_helper

        compile_helper()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = (
        seq_len > 16
        and seq_len <= 16384
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
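    # Example: with 32 attention heads, tensor-parallel size 4 and micro-batch
    # size 2, attn_batch_size = (32 / 4) * 2 = 16, which satisfies the
    # attn_batch_size % 4 == 0 requirement above.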
    # Print a warning.
    if not (
        (args.fp16 or args.bf16)
        and custom_kernel_constraint
        and args.masked_softmax_fusion
    ):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    # Always build on rank zero first.
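    # NOTE: the fused_kernels.load() calls below are commented out in this
    # version, so no kernels are actually compiled or loaded here; the barriers
    # are kept so that every rank still passes through this point together.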
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        #fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        #fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )


def _initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        #args.rank = torch.distributed.get_rank()
        #args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert (
                    args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process group. init_method is taken from args.dist_url,
        # a standard torch.distributed init URL such as 'env://' or 'tcp://host:port'.
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            init_method=args.dist_url,
            timeout=timedelta(minutes=args.distributed_timeout_minutes),
        )

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                args.fp8 is not None,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
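        # Example: a base seed of 1234 on pipeline rank 2 and data-parallel
        # rank 3 becomes 1234 + 100 * 2 + 10 * 3 = 1464 when
        # data_parallel_random_init is enabled.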
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # PyTorch 1.10+: keep the profiling executor on, but disable the JIT
        # fusers (including nvfuser) in this version.
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compile JIT functions before the main training steps."""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(
        args.ffn_hidden_size // args.tensor_model_parallel_size,
        dtype=dtype,
        device="cuda",
    )
    input = torch.rand(
        (
            args.seq_length,
            args.micro_batch_size,
            args.ffn_hidden_size // args.tensor_model_parallel_size,
        ),
        dtype=dtype,
        device="cuda",
    )
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
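    # With sequence parallelism each rank warms up on its own slice of the
    # sequence, e.g. seq_length 4096 with tensor-parallel size 8 gives
    # 512-token warmup inputs.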
    input = torch.rand(
        (seq_length, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    residual = torch.rand(
        (seq_length, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
        residual
    )
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip(
        [False, True], [True, True], [True, True]
    ):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()