# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Communications utilities."""


import torch

from megatron.core import parallel_state
from megatron.core import mpu


# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

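# Illustrative usage sketch (not part of the original module; buffer shape,
# dtype, and `run_model_chunk` are hypothetical): a middle pipeline stage
# would typically pair the two helpers above during pipelined inference.
#
#     recv_buffer = torch.empty((batch_size, seq_len, hidden_size),
#                               dtype=torch.float16,
#                               device=torch.cuda.current_device())
#     recv_from_prev_pipeline_rank_(recv_buffer)   # no-op on the first stage
#     output = run_model_chunk(recv_buffer)        # this stage's layers
#     send_to_next_pipeline_rank(output)           # no-op on the last stage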


def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda



def _is_cuda_contiguous(tensor):
    """Check if a tensor is not none, is cuda, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor

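# Illustrative usage sketch (hypothetical names): after the last stage has
# sampled `new_tokens` of shape [batch_size], every pipeline rank can obtain a
# copy so all stages stay in sync for the next generation step.
#
#     new_tokens = broadcast_from_last_pipeline_stage(
#         [batch_size], torch.int64,
#         tensor=new_tokens if mpu.is_pipeline_last_stage() else None)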

def _send_and_recv_from_last_to_first_pipeline_stage(tensor=None):
    """Point-to-point send of `tensor` from the last pipeline stage to the
    first one; ranks that are neither first nor last return None."""
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    if is_last_stage or is_first_stage:
        if is_first_stage:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor,
                mpu.get_pipeline_model_parallel_last_rank())
            reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        elif is_last_stage:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor,
                mpu.get_pipeline_model_parallel_first_rank())
            reqs = torch.distributed.batch_isend_irecv([send_next_op])

        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

        return tensor


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        tensor = _send_and_recv_from_last_to_first_pipeline_stage(tensor)
    else:
        tensor = None

    return tensor

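# Illustrative usage sketch (hypothetical names): only the first stage needs
# the sampled log-probabilities, so a full broadcast is avoided; ranks that
# are neither first nor last simply get None back.
#
#     logprobs = broadcast_from_last_to_first_pipeline_stage(
#         [batch_size, seq_len], torch.float32,
#         tensor=output_log_probs if mpu.is_pipeline_last_stage() else None)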


def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        tensor_ = _send_and_recv_from_last_to_first_pipeline_stage(tensor_)
        # Update the first stage tensor
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_

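# Illustrative usage sketch (hypothetical names): unlike the broadcast variant
# above, this updates a preallocated, possibly non-contiguous buffer in place
# on the first stage, e.g. a column slice of a larger token matrix.
#
#     copy_from_last_to_first_pipeline_stage(
#         [batch_size], torch.int64, tokens[:, context_length])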


def broadcast_tensor(size, dtype, tensor=None, rank=0, data_parallel=False):
    """Given size and type of a tensor on all ranks and the tensor value
    only on a specific rank, broadcast from that rank to all other ranks.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """
    if data_parallel:
        rank = parallel_state.get_model_parallel_src_rank()

    if torch.distributed.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    group = None
    if data_parallel:
        group = parallel_state.get_model_parallel_group()

    torch.distributed.broadcast(tensor, rank, group=group)

    return tensor

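# Illustrative usage sketch (hypothetical names): broadcast a prompt-token
# tensor that only exists on rank 0 to every rank, given a size and dtype
# that all ranks already agree on.
#
#     tokens = broadcast_tensor(
#         [batch_size, max_len], torch.int64,
#         tensor=tokens if torch.distributed.get_rank() == 0 else None, rank=0)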


def broadcast_list(size, dtype, list_values=None, rank=0, data_parallel=False):
    """Broadcast a list of values with a given type.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """

    tensor = None

    if data_parallel:
        if parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank():
            tensor = torch.tensor(list_values, dtype=dtype,
                                  device=torch.cuda.current_device())

        rank = parallel_state.get_model_parallel_src_rank()
    else:
        if torch.distributed.get_rank() == rank:
            tensor = torch.tensor(list_values, dtype=dtype,
                                  device=torch.cuda.current_device())

    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank,
                            data_parallel=data_parallel)



def broadcast_int_list(size, int_list=None, rank=0, data_parallel=False):
    """Broadcast a list of integer values.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """

    return broadcast_list(size, torch.int64, list_values=int_list,
                          rank=rank, data_parallel=data_parallel)

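# Illustrative usage sketch (hypothetical names): a common pattern is to first
# broadcast a tensor's sizes as an int list, then broadcast the tensor itself
# using those sizes.
#
#     sizes = broadcast_int_list(2, int_list=[batch_size, max_len], rank=0)
#     tokens = broadcast_tensor(sizes.tolist(), torch.int64,
#                               tensor=tokens, rank=0)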


def broadcast_float_list(size, float_list=None, rank=0, data_parallel=False):
    """Broadcast a list of float values.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """

    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank, data_parallel=data_parallel)
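

# Illustrative usage sketch (hypothetical names): scalar sampling parameters
# chosen on rank 0 can be shipped to all ranks as a float list and unpacked.
#
#     vals = broadcast_float_list(2, float_list=[temperature, top_p], rank=0)
#     temperature, top_p = vals[0].item(), vals[1].item()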