# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Model and data parallel groups."""

import torch

from .utils import ensure_divisibility


# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
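
# Typical initialization sketch (this assumes torch.distributed has already been
# launched with one process per GPU, e.g. via torch.distributed.launch or torchrun):
#
#     torch.distributed.init_process_group(backend='nccl')
#     initialize_model_parallel(tensor_model_parallel_size_=2,
#                               pipeline_model_parallel_size_=4)
#     tp_rank = get_tensor_model_parallel_rank()
#     dp_group = get_data_parallel_group()
#
# The sizes 2 and 4 mirror the 16-GPU example in the initialize_model_parallel
# docstring; any sizes whose product divides the world size are valid.
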

def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None,
                              pipeline_model_parallel_split_rank_=None):
    """
    Initialize model and data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
                                              pipeline).
        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
                                            rank in pipeline with split point.


    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size_))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size_))
    # Get world size and rank. Ensure consistency.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    ensure_divisibility(world_size,
                        tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size
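    # For the 16-GPU example in the docstring (tensor=2, pipeline=4) this gives
    # data_parallel_size = 16 // (2 * 4) = 2, i.e. 8 tensor model-parallel
    # groups (16 // 2), 4 pipeline model-parallel groups (16 // 4) and
    # 8 data-parallel groups (16 // 2).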

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    if pipeline_model_parallel_split_rank_ is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
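    # Illustration with the 16-GPU example: for the first pipeline block
    # (i=0, start_rank=0, end_rank=4) the inner loop yields range(0, 4, 2) = [0, 2]
    # and range(1, 4, 2) = [1, 3], matching the data-parallel groups
    # [g0, g2], [g1, g3] listed in the docstring.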

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
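    # Illustration with the 16-GPU example: consecutive pairs [0, 1], [2, 3], ...,
    # [14, 15] form the tensor model-parallel groups, as listed in the docstring.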

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            if pipeline_model_parallel_split_rank_ is not None and \
                    ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                embedding_ranks = [ranks[0],
                                   ranks[pipeline_model_parallel_split_rank_],
                                   ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
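    # Illustration with the 16-GPU example: the pipeline groups are
    # range(i, 16, 4), e.g. [0, 4, 8, 12] for i=0, and the matching embedding
    # group holds the first and last stages ([0, 12]), plus the stage at
    # pipeline_model_parallel_split_rank_ when it is given and is not already
    # one of the two.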


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
        'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    assert _EMBEDDING_GROUP is not None, \
        'embedding group is not initialized'
    return _EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size"""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
            assert args.num_layers % num_ranks_in_encoder == 0, \
                    'num_layers must be divisible by number of ranks given to encoder'
            assert args.num_layers % num_ranks_in_decoder == 0, \
                    'num_layers must be divisible by number of ranks given to decoder'
            if is_pipeline_stage_before_split():
                num_layers = args.num_layers // num_ranks_in_encoder
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
                'num_layers must be divisible by pipeline_model_parallel_size'
            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
    else:
        num_layers = args.num_layers
    return num_layers
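
# Worked example (hypothetical settings): with args.num_layers = 12,
# pipeline_model_parallel_size = 4 and pipeline_model_parallel_split_rank = 2
# for an encoder-and-decoder model, each of the 2 encoder ranks holds
# 12 // 2 = 6 layers and each of the 2 decoder ranks holds 12 // 2 = 6 layers;
# without the split, every pipeline rank would hold 12 // 4 = 3 layers.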


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)
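
# With an interleaved schedule (virtual_pipeline_model_parallel_size_ set),
# these checks additionally require the current virtual rank to be 0 (first
# stage) or the last virtual rank (last stage), unless ignore_virtual=True.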


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes decoder block and next
    stage executes encoder block for a model with both encoder and
    decoder."""
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and \
            is_pipeline_stage_after_split(rank+1)
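
# Example: with pipeline_model_parallel_size = 4 and a split rank of 2,
# pipeline ranks 0-1 are before the split (encoder) and ranks 2-3 are after
# it (decoder); rank 1 is "at the split" because the next stage, rank 2,
# starts the decoder.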


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
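
# Example: with a tensor model-parallel world size of 2, global ranks 4 and 5
# both map to source rank 4, since (4 // 2) * 2 == (5 // 2) * 2 == 4.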


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
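
# Example: if _PIPELINE_GLOBAL_RANKS is [0, 4, 8, 12] (as in the 16-GPU example),
# the next rank after global rank 12 wraps around to 0 and the previous rank
# of global rank 0 wraps around to 12.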


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None