# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
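
    Example usage (a minimal sketch, not required by this function; the
    args_defaults value shown mirrors the GPT pretraining scripts and is
    illustrative only):

        from megatron.initialize import initialize_megatron
        from megatron import get_args

        initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
        args = get_args()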
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
        assert args.load is not None, '--use-checkpoint-args requires --load argument'
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)
        
    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
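        # A sketch of the intended lazy flow (illustrative only): the caller
        # receives `finish_mpu_init`, lets its own DDP manager bring up
        # torch.distributed, and then calls the returned function to complete
        # the distributed and RNG initialization.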
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
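    # For example, seq_length 2048 with 32 attention heads, tensor-parallel
    # size 8, and micro-batch size 4 gives attn_batch_size = (32 / 8) * 4 = 16,
    # so both divisibility-by-4 checks above are satisfied.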
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)
    
    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
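            # e.g. with 8 GPUs per node, global rank 11 maps to local device 3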
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=args.distributed_timeout_minutes))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank,
                                          args.untie_embeddings_and_output_weights)
            if args.rank == 0:
                print(f'> initialized tensor model parallel with size '
                      f'{mpu.get_tensor_model_parallel_world_size()}')
                print(f'> initialized pipeline model parallel with size '
                      f'{mpu.get_pipeline_model_parallel_world_size()}')


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
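        # For example, seed_ = 1234 on pipeline-parallel rank 2 gives
        # 1234 + 200 = 1434; with data_parallel_random_init and data-parallel
        # rank 3 it becomes 1434 + 30 = 1464.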
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compile JIT functions before the main training steps."""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
                      dtype=dtype, device='cuda')
    input = torch.rand((args.seq_length, args.micro_batch_size,
                        args.ffn_hidden_size // args.tensor_model_parallel_size),
                       dtype=dtype, device='cuda')
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
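    # (two warmup passes below: input.requires_grad False and then True, so
    # both JIT specializations are compiled before training starts)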
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                       dtype=dtype, device='cuda')
    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                          dtype=dtype, device='cuda')
    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
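    # expand_as broadcasts the per-hidden-unit bias to the same
    # (seq_length, micro_batch_size, hidden_size) shape as the residual.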
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()