# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os

import numpy as np
import torch

from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)

def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
Raul Puri's avatar
Raul Puri committed
34
35
36
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only 
    data processing. In general this arg should not be set unless you know 
37
38
    what you are doing.
    Returns a function to finalize distributed env initialization 
Boris Fomitchev's avatar
Boris Fomitchev committed
39
    (optionally, only when args.lazy_mpu_init == True)
40
41

"""
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()
        
        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals.
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()
        
        # Autoresume.
        _init_autoresume()

        # Compile dataset C++ code.
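        # Only rank 0 runs the build; the barrier below keeps the other ranks
        # from trying to use the compiled extension before the build finishes.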
        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        # Simple barrier
        torch.distributed.barrier()
        
        # No continuation function
        return None
        

def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
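        # MASTER_ADDR / MASTER_PORT (like args.rank and args.world_size) are
        # normally supplied by the launcher environment, e.g.
        # torch.distributed.launch or a SLURM wrapper; the defaults above are
        # only a single-node fallback.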
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
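        # Ranks within the same pipeline stage share this base seed (so, for
        # example, data-parallel replicas initialize identically); per-rank
        # divergence for tensor parallelism is handled by
        # model_parallel_cuda_manual_seed() below.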
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError(
            'Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()