# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
Raul Puri's avatar
Raul Puri committed
38
39
40
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only 
    data processing. In general this arg should not be set unless you know 
41
42
    what you are doing.
    Returns a function to finalize distributed env initialization 
Boris Fomitchev's avatar
Boris Fomitchev committed
43
    (optionally, only when args.lazy_mpu_init == True)
44
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization and random seed setup.
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()
        
        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related state; we only set the basic
        # DDP globals here ...
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # ... and return a function for the external DDP manager to call once
        # it has DDP initialized.
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()
        
        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None
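
# A minimal usage sketch of the lazy-initialization contract above
# (hypothetical; assumes `lazy_mpu_init` can be injected through
# `args_defaults`). An external DDP manager calls the returned continuation
# once it has torch.distributed set up:
#
#     finish = initialize_megatron(args_defaults={'lazy_mpu_init': True})
#     # ... external framework initializes torch.distributed / DDP here ...
#     if finish is not None:
#         finish()  # completes mpu initialization and seeds the RNGs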


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp-based
    # optimization and upper-triangular optimization (for the causal mask).
    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
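    # Worked example with a hypothetical configuration: seq_length=1024,
    # 16 attention heads, tensor-model-parallel size 2 and micro-batch size 4
    # give attn_batch_size = (16 / 2) * 4 = 32; since 16 < 1024 <= 2048,
    # 1024 % 4 == 0 and 32 % 4 == 0, the constraint is satisfied.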
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking the optimized'
                  ' fused softmax kernel are not met. Falling back'
                  ' to unfused kernel invocations.', flush=True)
    
    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)



def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
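        # For example, with MASTER_ADDR=node0 and MASTER_PORT=6000
        # (hypothetical values), init_method becomes 'tcp://node0:6000'.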
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size)
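            # For example (hypothetical sizes): world_size=16 with
            # tensor_model_parallel_size=2 and pipeline_model_parallel_size=4
            # yields a data-parallel size of 16 / (2 * 4) = 2.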


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
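        # For example, with a hypothetical seed_ of 1234, pipeline stage 0
        # uses 1234, stage 1 uses 1334, stage 2 uses 1434, and so on.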
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()