"vscode:/vscode.git/clone" did not exist on "e782eb7e6a9af1e4b81adf7459f737b29fb388ea"
p2p_communication.py 11.7 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from functools import reduce
import operator

import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape).
        dtype_: optional, used when the dtype of the tensor to be
                communicated differs from args.params_dtype.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    # Some legacy inference code doesn't set the tensor shape, so fall back to
    # the standard GPT/BERT shape here. This could be removed once the
    # inference code is updated to provide tensor_shape.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    override_scatter_gather_tensors_in_pipeline = False
    if args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
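        # Scatter-gather optimization: when the flattened tensor divides
        # evenly across the tensor-model-parallel group, each rank only
        # communicates a 1/TP-world-size chunk; the full tensor is
        # re-assembled with gather_split_1d_tensor() after the receive.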
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
            tensor_chunk_shape = tensor_chunk_shape // \
                mpu.get_tensor_model_parallel_world_size()
        else:
            tensor_chunk_shape = tensor_shape
            override_scatter_gather_tensors_in_pipeline = True
    else:
        tensor_chunk_shape = tensor_shape
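    # Communication buffers use the parameter dtype by default; with
    # --fp32-residual-connection the communicated activations are fp32.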
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if args.use_ring_exchange_p2p:
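        # Note: torch.distributed.ring_exchange() is not part of stock
        # PyTorch; this path assumes a custom build that provides it
        # (enabled via --use-ring-exchange-p2p).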
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
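        # Default path: post all point-to-point sends/recvs as a single
        # batched call and wait on the resulting requests.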
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
        # To protect against a race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                        requires_grad=True,
                                                        keep_graph=False)

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                        requires_grad=True,
                                                        keep_graph=False)

    return tensor_recv_prev, tensor_recv_next


def recv_forward(tensor_shape=None, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""

    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv', log_level=2).start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(tensor_shape=None, timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv', log_level=2).start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""

    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send', log_level=2).start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send', log_level=2).start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv', log_level=2).start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv', log_level=2).start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape=None, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv',
               log_level=2).start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
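

# ---------------------------------------------------------------------------
# Illustrative sketch only; not part of the original module. It shows one
# simple way the helpers above can be combined for a single microbatch, in
# the spirit of the pipeline schedules in megatron/schedules.py.
# `forward_step_func` and `backward_step_func` are hypothetical callables
# supplied by the caller; they must accept None tensors on the first and
# last pipeline stages.
# ---------------------------------------------------------------------------
def _example_pipeline_microbatch(forward_step_func, backward_step_func,
                                 tensor_shape=None, timers=None):
    """Naive forward+backward for one microbatch using the p2p helpers."""
    # Forward: receive activations from the previous stage (None on the
    # first stage), compute, and pass activations to the next stage
    # (send_forward is a no-op on the last stage).
    input_tensor = recv_forward(tensor_shape=tensor_shape, timers=timers)
    output_tensor = forward_step_func(input_tensor)
    send_forward(output_tensor, tensor_shape=tensor_shape, timers=timers)

    # Backward: receive the gradient w.r.t. the output from the next stage
    # (None on the last stage), run backward, and send the gradient w.r.t.
    # the input to the previous stage.
    output_tensor_grad = recv_backward(tensor_shape=tensor_shape, timers=timers)
    input_tensor_grad = backward_step_func(
        input_tensor, output_tensor, output_tensor_grad)
    send_backward(input_tensor_grad, tensor_shape=tensor_shape, timers=timers)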