# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this argument should not be set unless you
    know what you are doing.
    Returns a function to finalize distributed environment initialization
    (only when args.lazy_mpu_init == True; otherwise returns None).
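
    Illustrative usage (a sketch; the extra-args provider and the defaults
    shown here are placeholders whose values depend on the training script)::

        initialize_megatron(extra_args_provider=add_my_extra_args,
                            args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})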
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
        assert args.load is not None, '--use-checkpoint-args requires --load argument'
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # PyTorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related state: only set the basic
        # DDP globals here,
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return a function for the external DDP manager to call once
        # it has initialized DDP.
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
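    # For example (illustrative values only): seq_length=2048 with 32
    # attention heads, tensor_model_parallel_size=8 and micro_batch_size=4
    # gives attn_batch_size = (32 / 8) * 4 = 16, so every constraint above
    # holds; the fused kernel is still used only with fp16/bf16 and
    # masked_softmax_fusion enabled, as checked below.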
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)

    # Always build on rank zero first.
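    # The non-zero ranks wait on the barrier and then load the extension that
    # rank zero has already built, which avoids concurrent compilation of the
    # same kernels.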
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
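            # With one process per GPU this maps the global rank to a local
            # device id; e.g., rank 11 on a node with 8 visible GPUs uses
            # device 11 % 8 = 3.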
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process.
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)
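            # The model-parallel groups partition the world size as
            # TP x PP x DP; e.g., 16 ranks with tensor_model_parallel_size=4
            # and pipeline_model_parallel_size=2 leave 16 / (4 * 2) = 2
            # data-parallel replicas.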


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure that different data parallel ranks get different seeds.
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
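        # For example, a base seed of 1234 becomes 1234 + 100 * 2 = 1434 on
        # pipeline stage 2, and 1434 + 10 * 3 = 1464 if data-parallel rank 3
        # is folded in as well.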
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compile JIT functions before the main training steps."""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
                      dtype=dtype, device='cuda')
    input = torch.rand((args.seq_length, args.micro_batch_size,
                        args.ffn_hidden_size // args.tensor_model_parallel_size),
                       dtype=dtype, device='cuda')
    # Warm up the JIT fusions with the input requires_grad states seen during
    # both forward prop and recomputation.
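    # The profiling JIT executor is assumed to specialize fused kernels on the
    # requires_grad state of their inputs, which is why both combinations
    # below are exercised.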
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                       dtype=dtype, device='cuda')
    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                          dtype=dtype, device='cuda')
    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
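    # expand_as() broadcasts the per-channel bias to the residual's shape as a
    # view rather than materializing a full copy.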
    dropout_rate = 0.1
    # Warm up the JIT fusions with the input requires_grad states seen during
    # both forward prop and recomputation.
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()