communication.py 5.88 KB
Newer Older
mshoeybi's avatar
working  
mshoeybi committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Communications utilities."""


import torch

mshoeybi's avatar
mshoeybi committed
21
22
23
from megatron import mpu


mshoeybi's avatar
mshoeybi committed
24

mshoeybi's avatar
mshoeybi committed
25
# TODO: use functions from megatron/p2p
mshoeybi's avatar
mshoeybi committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Fill ``recv_buffer`` in place with data from the previous pipeline stage.

    The first pipeline stage has no predecessor, so the call is a no-op
    there. On every other stage ``recv_buffer`` must be provided; it is
    overwritten with what the previous pipeline rank sends. The call waits
    for the receive request and then synchronizes the CUDA device.
    """
    if mpu.is_pipeline_first_stage():
        return
    assert recv_buffer is not None
    prev_rank = mpu.get_pipeline_model_parallel_prev_rank()
    recv_op = torch.distributed.P2POp(torch.distributed.irecv,
                                      recv_buffer, prev_rank)
    for handle in torch.distributed.batch_isend_irecv([recv_op]):
        handle.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()



mshoeybi's avatar
mshoeybi committed
42
# TODO: use functions from megatron/p2p
mshoeybi's avatar
mshoeybi committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def send_to_next_pipeline_rank(tensor=None):
    """Send ``tensor`` to the next pipeline stage.

    The last pipeline stage has no successor, so nothing is sent there.
    Every other stage must pass a non-None ``tensor``. The call waits for
    the send request and then synchronizes the CUDA device.
    """
    if mpu.is_pipeline_last_stage():
        return
    assert tensor is not None
    next_rank = mpu.get_pipeline_model_parallel_next_rank()
    send_op = torch.distributed.P2POp(torch.distributed.isend,
                                      tensor, next_rank)
    for handle in torch.distributed.batch_isend_irecv([send_op]):
        handle.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()



mshoeybi's avatar
mshoeybi committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor held by the last pipeline stage to every rank of
    the pipeline-model-parallel group.

    Args:
        size: shape of the receive buffer allocated on non-last stages.
        dtype: dtype of that receive buffer.
        tensor: data to send; required (CUDA, contiguous) on the last stage
            and replaced by a fresh buffer on every other stage.

    Returns:
        The broadcast tensor on every participating rank.
    """
    if not mpu.is_pipeline_last_stage():
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    else:
        assert tensor is not None
        assert tensor.is_cuda
        assert tensor.is_contiguous()

    # The last rank of the pipeline group is the broadcast source.
    torch.distributed.broadcast(tensor,
                                mpu.get_pipeline_model_parallel_last_rank(),
                                mpu.get_pipeline_model_parallel_group())
    return tensor


mshoeybi's avatar
working  
mshoeybi committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from the last stage into the first stage.

    Only the first and last pipeline stages take part; they communicate
    over ``mpu.get_embedding_group()``.

    Args:
        size: shape of the receive buffer allocated on the first stage.
        dtype: dtype of that receive buffer.
        tensor: data to send; required (CUDA, contiguous) on the last stage,
            replaced by a fresh buffer on the first stage.

    Returns:
        The broadcast tensor on the first and last stages, ``None`` on all
        intermediate stages.
    """
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    # Only first and last stage pipeline stages need to be involved.
    if not (is_last_stage or is_first_stage):
        return None

    if is_last_stage:
        assert tensor is not None
        assert tensor.is_cuda
        assert tensor.is_contiguous()
        # A rank that is both first and last stage (single-stage pipeline)
        # would only broadcast to itself; the data is already in place.
        if is_first_stage:
            return tensor
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_embedding_group()
    # Broadcast from last stage into the first stage.
    torch.distributed.broadcast(tensor, src, group)
    return tensor



mshoeybi's avatar
mshoeybi committed
104
105
106
107
108
109
110
111
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place.

    Args:
        size: shape of the staging buffer allocated on the first stage when
            ``tensor`` is not contiguous.
        dtype: dtype of that staging buffer.
        tensor: CUDA tensor required on both the first and the last stage.
            On the last stage it is the source; on the first stage its
            contents are overwritten with the last stage's values.
    """
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    # Only first and last stage pipeline stages need to be involved.
    if not (is_last_stage or is_first_stage):
        return

    assert tensor is not None
    assert tensor.is_cuda

    # A rank that is both first and last stage would broadcast to itself
    # and then copy the same values back onto the same tensor: skip it.
    if is_first_stage and is_last_stage:
        return

    is_contiguous = tensor.is_contiguous()
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_embedding_group()

    # The broadcast works on a contiguous buffer: non-contiguous tensors are
    # staged through a contiguous copy (sender) or a fresh buffer (receiver).
    if is_contiguous:
        tensor_ = tensor
    elif is_last_stage:
        tensor_ = tensor.contiguous()
    else:
        tensor_ = torch.empty(size,
                              dtype=dtype,
                              device=torch.cuda.current_device())

    # Broadcast from last stage into the first stage.
    torch.distributed.broadcast(tensor_, src, group)

    # Update the first stage tensor when the data arrived in a staging buffer.
    if is_first_stage and not is_contiguous:
        tensor[...] = tensor_
mshoeybi's avatar
working  
mshoeybi committed
131
132


mshoeybi's avatar
mshoeybi committed
133

mshoeybi's avatar
working  
mshoeybi committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """ Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.

    Args:
        size: shape of the receive buffer allocated on non-source ranks.
        dtype: dtype of that receive buffer.
        tensor: data to send; required (CUDA, contiguous) on ``rank`` and
            replaced by a fresh buffer on every other rank.
        rank: rank in the default process group that owns the data.

    Returns:
        The broadcast tensor on every rank.
    """
    if torch.distributed.get_rank() == rank:
        assert tensor is not None
        assert tensor.is_cuda
        # Like the other broadcast helpers, require a contiguous buffer so a
        # bad input fails here rather than inside the collective.
        assert tensor.is_contiguous()
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(tensor, rank)

    return tensor


mshoeybi's avatar
mshoeybi committed
152

mshoeybi's avatar
mshoeybi committed
153
154
155
156
157
158
159
160
161
162
163
def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values with a given type.

    On the source ``rank`` the values are packed into a CUDA tensor of
    ``dtype``. Every other rank passes ``None`` and receives into a buffer
    of ``size``. Returns the resulting tensor on all ranks.
    """
    is_source = torch.distributed.get_rank() == rank
    packed = (torch.tensor(list_values, dtype=dtype,
                           device=torch.cuda.current_device())
              if is_source else None)
    return broadcast_tensor(size, dtype, tensor=packed, rank=rank)


mshoeybi's avatar
mshoeybi committed
164

mshoeybi's avatar
working  
mshoeybi committed
165
166
167
def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of integer values.

    Args:
        size: shape of the receive buffer on non-source ranks.
        int_list: the values; only needed on the source ``rank``.
        rank: the source rank.

    Returns:
        A ``torch.int64`` tensor holding the values on every rank.
    """
    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)


mshoeybi's avatar
mshoeybi committed
171

mshoeybi's avatar
mshoeybi committed
172
173
def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values.

    Args:
        size: shape of the receive buffer on non-source ranks.
        float_list: the values; only needed on the source ``rank``.
        rank: the source rank.

    Returns:
        A ``torch.float32`` tensor holding the values on every rank.
    """
    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank)