# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True).
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
        assert args.load is not None, '--use-checkpoint-args requires --load argument'
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)

    # Set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals,
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return a function for the external DDP manager
        # to call when it has DDP initialized.
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None
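

# Usage sketch (illustrative, not part of this module's API): a training
# entry point typically calls initialize_megatron() once, before building
# models or data loaders, e.g.
#
#     from megatron.initialize import initialize_megatron
#     initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
#
# The args_defaults value shown here is only an example.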


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
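    # Worked example (illustrative numbers): seq_length=2048 with 32
    # attention heads, tensor_model_parallel_size=8, and micro_batch_size=4
    # gives attn_batch_size = (32 / 8) * 4 = 16, satisfying all of the
    # checks above.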
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking the optimized'
                  ' fused softmax kernel are not met. Falling back'
                  ' to unfused kernel invocations.', flush=True)
    
    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
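            # E.g., with 8 GPUs per node (illustrative numbers), global
            # rank 11 maps to local device 11 % 8 = 3.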
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process.
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
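        # E.g., with seed_=1234, pipeline rank 2, and data-parallel rank 3
        # (when data_parallel_random_init is set; illustrative numbers):
        # seed = 1234 + 100 * 2 + 10 * 3 = 1464.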
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
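    # E.g., torch.__version__ == '1.11.0' parses to TORCH_MAJOR=1 and
    # TORCH_MINOR=11, so the nvfuser branch below is taken (illustrative
    # version string).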
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compile JIT functions before the main training steps."""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
                      dtype=dtype, device='cuda')
    input = torch.rand((args.seq_length, args.micro_batch_size,
                        args.ffn_hidden_size // args.tensor_model_parallel_size),
                       dtype=dtype, device='cuda')
    # Warmup JIT fusions with the input requires_grad state of both forward
    # prop and recomputation.
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                       dtype=dtype, device='cuda')
    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                          dtype=dtype, device='cuda')
    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
    dropout_rate = 0.1
    # Warmup JIT fusions with the input requires_grad state of both forward
    # prop and recomputation.
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()