"docs/references/setup_github_runner.md" did not exist on "566d61d90fd508f09179788e8b719a748af8e65b"
initialize.py 11.7 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import core
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
Raul Puri's avatar
Raul Puri committed
31
32
33
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only 
    data processing. In general this arg should not be set unless you know 
34
35
    what you are doing.
    Returns a function to finalize distributed env initialization 
Boris Fomitchev's avatar
Boris Fomitchev committed
36
    (optionally, only when args.lazy_mpu_init == True)
37
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
        assert args.load is not None, '--use-checkpoint-args requires --load argument'
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)
        
    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()
        
        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None
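
# Example usage (a minimal sketch): training entry points call this function at
# startup; the args_defaults value below is only illustrative.
#
#   from megatron.initialize import initialize_megatron
#   initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})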


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
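    # Example (hypothetical values): seq_len=2048, num_attention_heads=32,
    # tensor_model_parallel_size=8 and micro_batch_size=4 give
    # attn_batch_size = (32 / 8) * 4 = 16, so the divisibility checks pass.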
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)
    
    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)



def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
    # Call the init process
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)
            core.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
                                           args.pipeline_model_parallel_split_rank)
            print(f'> initialized tensor model parallel with size '
                  f'{core.get_tensor_model_parallel_world_size()}')
            print(f'> initialized pipeline model parallel with size '
                  f'{core.get_pipeline_model_parallel_world_size()}')
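            # Example (hypothetical sizes): a world size of 16 with
            # tensor_model_parallel_size=2 and pipeline_model_parallel_size=4
            # leaves a data-parallel size of 16 / (2 * 4) = 2.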


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
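        # Example: seed_=1234 on pipeline stage 2 gives seed = 1234 + 200 = 1434;
        # with data_parallel_random_init and data-parallel rank 3 this becomes
        # 1434 + 30 = 1464.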
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """ Compilie JIT functions before the main training steps """
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
                      dtype=dtype, device='cuda')
    input = torch.rand((args.seq_length, args.micro_batch_size,
                        args.ffn_hidden_size // args.tensor_model_parallel_size),
                       dtype=dtype, device='cuda')
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                       dtype=dtype, device='cuda')
    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                          dtype=dtype, device='cuda')
    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()