# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
20
import time
Mohammad's avatar
Mohammad committed
21

22
import numpy as np
Mohammad's avatar
Mohammad committed
23
import torch
Ryan Prenger's avatar
Ryan Prenger committed
24
from datetime import timedelta
Mohammad's avatar
Mohammad committed
25

26
from megatron import fused_kernels
27
28
29
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
Mohammad's avatar
Mohammad committed
30
from megatron import mpu
31
from megatron.global_vars import set_global_variables
32
33
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)
34
35
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
36

Mohammad's avatar
Mohammad committed
37

38
def initialize_megatron(extra_args_provider=None, args_defaults=None,
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.

    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.

    Args:
        extra_args_provider: optional callable forwarded to
            `set_global_variables` (extra command-line arguments).
        args_defaults: optional dict of argument defaults forwarded to
            `set_global_variables`. ``None`` (the default) means an empty
            dict is used.
        ignore_unknown_args: forwarded to `set_global_variables`.
        allow_no_cuda: when False (the default), assert that CUDA is
            available.

    Returns:
        When ``args.lazy_mpu_init`` is set, a function that finishes the
        distributed / model-parallel initialization, to be called by an
        external DDP manager once it has DDP initialized. Otherwise
        ``None``: initialization, autoresume and dependency compilation
        have already been completed.
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Build a fresh dict per call. A `{}` default would be a single object
    # shared by every call, so any downstream mutation would leak between
    # calls.
    if args_defaults is None:
        args_defaults = {}

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    # Set pytorch JIT layer fusion options.
    _set_jit_fusion_options()

    args = get_args()
    if args.lazy_mpu_init:
        # Delayed initialization of DDP-related state: force CPU
        # initialization and only set the basic tensor-model-parallel
        # globals here.
        args.use_cpu_initialization = True
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # NOTE(review): the global rank is used as the tensor-model-parallel
        # rank here; confirm this is intended when other parallelism
        # dimensions are > 1.
        set_tensor_model_parallel_rank(args.rank)
        # Return the continuation for the external DDP manager to call
        # once it has DDP initialized.
        return finish_mpu_init

    # Megatron's MPU is the master. Complete initialization right away.
    finish_mpu_init()

    # Autoresume.
    _init_autoresume()

    # Compile dependencies.
    _compile_dependencies()

    # No continuation function
    return None
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122


def _compile_dependencies():
    """Build the dataset C++ helper and load the fused kernels.

    Steps:
      1. On distributed rank 0 only, compile the dataset index builder
         (`megatron.data.dataset_utils.compile_helper`).
      2. Check the fused-softmax kernel constraints and print a warning
         (on ``args.rank == 0``) when they are not met.
      3. Load the fused kernels: rank 0 first, the other ranks after a
         barrier, then a final barrier for everyone.
    """
    args = get_args()
    is_rank_zero = torch.distributed.get_rank() == 0

    # Dataset C++ code (rank 0 only).
    # TODO: move this to ninja
    if is_rank_zero:
        helper_start = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        helper_elapsed = time.time() - helper_start
        print('>>> done with dataset index builder. Compilation time: '
              '{:.3f} seconds'.format(helper_elapsed), flush=True)

    # Custom kernel constraints check. Sequence length and attention batch
    # size must satisfy these to enable the warp-based optimization and the
    # upper-triangular optimization (for causal mask).
    seq_len = args.seq_length
    heads_per_partition = (args.num_attention_heads /
                           args.tensor_model_parallel_size)
    attn_batch_size = heads_per_partition * args.micro_batch_size
    shape_supported = (16 < seq_len <= 4096
                       and seq_len % 4 == 0
                       and attn_batch_size % 4 == 0)
    fused_softmax_usable = ((args.fp16 or args.bf16)
                            and shape_supported
                            and args.masked_softmax_fusion)
    if not fused_softmax_usable and args.rank == 0:
        print('WARNING: constraints for invoking optimized'
              ' fused softmax kernel are not met. We default'
              ' back to unfused kernel invocations.', flush=True)

    # Rank 0 builds and loads first; every other rank waits at the barrier
    # and only then loads.
    if is_rank_zero:
        kernels_start = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
    torch.distributed.barrier()
    if not is_rank_zero:
        fused_kernels.load(args)

    # Make sure every rank got through the compilation phase before moving
    # on (presumably this also ensures the build lock has been released).
    torch.distributed.barrier()
    if is_rank_zero:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - kernels_start), flush=True)


Mohammad's avatar
Mohammad committed
154
155
156
157
158

def _initialize_distributed():
    """Initialize torch.distributed and mpu.

    If the default process group already exists, it is reused and
    ``args.rank`` / ``args.world_size`` are overwritten with its values.
    Otherwise the CUDA device is selected from ``args.rank`` (when GPUs are
    visible) and the default process group is created. In both cases the
    tensor/pipeline/data-parallel groups are then set up via `mpu` when GPUs
    are visible, unless they already exist.
    """
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process. This must stay inside the `else` branch:
        # creating the default process group a second time is an error,
        # which e.g. the lazy-MPU path (DDP already set up externally)
        # would otherwise hit.
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)
Mohammad's avatar
Mohammad committed
197
198
199
200
201
202
203
204
205
206
207


def _init_autoresume():
    """Set autoresume start time.

    Does nothing when no ADLR autoresume object is configured. Otherwise
    `autoresume.init()` is called between two distributed barriers, so all
    ranks enter and leave it together.
    """
    autoresume = get_adlr_autoresume()
    if not autoresume:
        return
    torch.distributed.barrier()
    autoresume.init()
    torch.distributed.barrier()


208
def _set_random_seed(seed_, data_parallel_random_init=False):
Mohammad's avatar
Mohammad committed
209
    """Set random seed for reproducability."""
210
    if seed_ is not None and seed_ > 0:
211
212
213
214
215
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
Mohammad's avatar
Mohammad committed
216
217
218
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
219
        if torch.cuda.device_count() > 0:
220
            mpu.model_parallel_cuda_manual_seed(seed)
Mohammad's avatar
Mohammad committed
221
222
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
Mohammad's avatar
Mohammad committed
223
224


225
def write_args_to_tensorboard():
    """Write arguments to tensorboard.

    Each attribute of the global args object is logged as a text entry
    (name -> ``str(value)``) at step ``args.iteration``. Nothing is written
    when no tensorboard writer is configured.
    """
    args = get_args()
    writer = get_tensorboard_writer()
    if not writer:
        return
    for name in vars(args):
        text = str(getattr(args, name))
        writer.add_text(name, text, global_step=args.iteration)
233
234


Sangkug Lym's avatar
Sangkug Lym committed
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options.

    On PyTorch >= 1.10 this enables the profiling executor with nvfuser.
    It disables the legacy CPU/GPU fusers and the TensorExpr fuser, and turns
    off autodiff subgraph inlining. On older releases it turns off profiling
    and enables the legacy CPU/GPU fusers instead.
    """
    # Flags required to enable jit fusion kernels; only major.minor matter.
    version_fields = torch.__version__.split('.')
    torch_version = (int(version_fields[0]), int(version_fields[1]))

    if torch_version < (1, 10):
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        return

    # nvfuser
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(True)
    torch._C._debug_set_autodiff_subgraph_inlining(False)
255

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293

def warmup_jit_function():
    """Compile JIT functions before the main training steps.

    Runs `bias_gelu` and `bias_dropout_add_fused_train` five times for each
    of two ``requires_grad`` configurations, on random CUDA tensors in the
    run's compute dtype (bf16, fp16 or fp32). The configurations cover the
    grad-enable states seen in forward prop and in recomputation. The
    tensors are then released and `torch.cuda.empty_cache()` is called.
    """
    args = get_args()
    dtype = (torch.bfloat16 if args.bf16
             else torch.float16 if args.fp16
             else torch.float32)

    # Warmup fused bias+gelu on the (4 * hidden / TP)-wide MLP activation.
    ffn_width = int(args.hidden_size * 4 / args.tensor_model_parallel_size)
    bias = torch.rand(ffn_width, dtype=dtype, device='cuda')
    activation = torch.rand(
        (args.seq_length, args.micro_batch_size, ffn_width),
        dtype=dtype, device='cuda')
    # (bias, input) requires_grad states: forward prop, then recomputation.
    for bias_needs_grad, act_needs_grad in ((True, False), (True, True)):
        bias.requires_grad, activation.requires_grad = \
            bias_needs_grad, act_needs_grad
        for _ in range(5):
            output = bias_gelu(bias, activation)
    del bias, activation, output

    # Warmup fused bias+dropout+add on hidden-size activations.
    input_size = (args.seq_length, args.micro_batch_size, args.hidden_size)
    activation = torch.rand(input_size, dtype=dtype, device='cuda')
    residual = torch.rand(input_size, dtype=dtype, device='cuda')
    bias = torch.rand(args.hidden_size, dtype=dtype,
                      device='cuda').expand_as(residual)
    # (input, bias, residual) requires_grad states: forward prop, then
    # recomputation.
    grad_states = ((False, True, True), (True, True, True))
    for act_needs_grad, bias_needs_grad, res_needs_grad in grad_states:
        activation.requires_grad, bias.requires_grad, residual.requires_grad = \
            act_needs_grad, bias_needs_grad, res_needs_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(activation, bias,
                                                  residual, 0.1)
    del bias, activation, residual, output
    torch.cuda.empty_cache()