# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from functools import reduce
import operator

import torch

from megatron import get_args, core
from megatron.core import mpu


def _communicate_shapes(tensor_send_next, tensor_send_prev,
                        recv_prev, recv_next):
    """Communicate tensor shapes between stages. Used to communicate 
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
    Returns:
        (recv_prev_shape, recv_next_shape)
    """

    args = get_args()
    recv_prev_shape_tensor = None
    recv_next_shape_tensor = None
    send_prev_shape_tensor = None
    send_next_shape_tensor = None
    if recv_prev:
        recv_prev_shape_tensor = torch.empty((3,),
                                             device=torch.cuda.current_device(),
                                             dtype=torch.int64)
    if recv_next:
        recv_next_shape_tensor = torch.empty((3,),
                                             device=torch.cuda.current_device(),
                                             dtype=torch.int64)
    if tensor_send_prev is not None:
        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
                                              device=torch.cuda.current_device(),
                                              dtype=torch.int64)
    if tensor_send_next is not None:
        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
                                              device=torch.cuda.current_device(),
                                              dtype=torch.int64)

    if args.use_ring_exchange_p2p:
        torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
                                        tensor_recv_prev=recv_prev_shape_tensor,
                                        tensor_send_next=send_next_shape_tensor,
                                        tensor_recv_next=recv_next_shape_tensor,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if send_prev_shape_tensor is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, send_prev_shape_tensor,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if recv_prev_shape_tensor is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_prev_shape_tensor,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if send_next_shape_tensor is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, send_next_shape_tensor,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if recv_next_shape_tensor is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_next_shape_tensor,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against a race condition when using batch_isend_irecv().
        # This should be removed once the bug with batch_isend_irecv() is resolved.
        torch.cuda.synchronize()

    recv_prev_shape = [0, 0, 0]
    if recv_prev_shape_tensor is not None:
        recv_prev_shape = recv_prev_shape_tensor.tolist()

    recv_next_shape = [0, 0, 0]
    if recv_next_shape_tensor is not None:
        recv_next_shape = recv_next_shape_tensor.tolist()

    return recv_prev_shape, recv_next_shape
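

# Illustrative sketch (not part of the original module): how the shape
# handshake above is meant to be used when --variable-seq-lengths is set.
# Everything here except _communicate_shapes itself is a hypothetical
# stand-in; the function is defined but never called in this file.
def _example_variable_length_recv_shape(activation):
    """Sketch: exchange shapes ahead of the payload so the receiver can
    allocate a correctly sized buffer for a variable-length microbatch."""
    # Each microbatch may carry a different sequence length, so the
    # (seq_len, micro_batch, hidden) size is communicated first.
    recv_prev_shape, recv_next_shape = _communicate_shapes(
        tensor_send_next=activation,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False)
    # recv_prev_shape is a plain Python list such as [seq, batch, hidden]
    # describing the tensor the previous rank is about to send.
    return recv_prev_shape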


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape).
        dtype_: optional, this is used when the tensor that needs to be
                communicated is different from args.params_dtype.
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    # Some legacy inference code doesn't set the tensor shape, do so now
    # for the normal values for gpt/bert. This could be removed if inference
    # code is changed to provide tensor_shape.
    if not args.variable_seq_lengths:
        if tensor_shape is None:
            recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
            recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
        else:
            recv_prev_shape = tensor_shape
            recv_next_shape = tensor_shape
    else:
        recv_prev_shape, recv_next_shape = \
            _communicate_shapes(tensor_send_next,
                                tensor_send_prev,
                                recv_prev,
                                recv_next)

    override_scatter_gather_tensors_in_pipeline = False
    if args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
        recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
        recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
        if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
                recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
            recv_prev_chunk_shape = recv_prev_chunk_shape // \
                mpu.get_tensor_model_parallel_world_size()
            recv_next_chunk_shape = recv_next_chunk_shape // \
                mpu.get_tensor_model_parallel_world_size()
        else:
            recv_prev_chunk_shape = recv_prev_shape
            recv_next_chunk_shape = recv_next_shape
            override_scatter_gather_tensors_in_pipeline = True
    else:
        recv_prev_chunk_shape = recv_prev_shape
        recv_next_chunk_shape = recv_next_shape

    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
        if tensor_send_next is not None:
            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if args.use_ring_exchange_p2p:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
        # To protect against a race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline and \
            not args.sequence_parallel:
        if recv_prev:
            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                tensor_recv_prev).view(recv_prev_shape).requires_grad_()
            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
                                                               requires_grad=True,
                                                               keep_graph=False)

        if recv_next:
            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                tensor_recv_next).view(recv_next_shape).requires_grad_()
            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
                                                               requires_grad=True,
                                                               keep_graph=False)

    return tensor_recv_prev, tensor_recv_next
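

# Illustrative sketch (assumption, not in the original file): the public
# helpers below are thin wrappers over _communicate; a full bidirectional
# exchange on an interior pipeline stage reduces to a single call like this.
def _example_bidirectional_exchange(output_tensor, input_tensor_grad,
                                    tensor_shape):
    """Sketch: send the activation forward and the gradient backward while
    receiving the matching tensors from both neighbors in one batched step."""
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,      # activation for the next stage
        tensor_send_prev=input_tensor_grad,  # gradient for the previous stage
        recv_prev=True,
        recv_next=True,
        tensor_shape=tensor_shape)
    return input_tensor, output_tensor_grad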


def recv_forward(tensor_shape=None, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""

    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv', log_level=2).start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(tensor_shape=None, timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv', log_level=2).start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""

    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send', log_level=2).start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-send').stop()
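

# Illustrative sketch (hypothetical): how an interior stage pairs recv_forward
# and send_forward for one microbatch. forward_step is an assumed stand-in for
# the model step driven by megatron/schedules.py, not a function in this file.
def _example_forward_microbatch(forward_step, tensor_shape):
    """Sketch: receive the previous stage's activation, run this stage's
    layers, and hand the result to the next stage."""
    input_tensor = recv_forward(tensor_shape=tensor_shape)  # None on stage 0
    output_tensor = forward_step(input_tensor)              # assumed callable
    send_forward(output_tensor, tensor_shape=tensor_shape)  # no-op on last stage
    return output_tensor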


def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send', log_level=2).start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv', log_level=2).start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv', log_level=2).start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor
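

# Illustrative sketch (assumption): the 1F1B steady state in
# megatron/schedules.py alternates the two fused helpers above so a send in
# one direction overlaps with a recv in the other. forward_step and
# backward_step are hypothetical stand-ins for the caller's compute steps.
def _example_1f1b_steady_state(forward_step, backward_step, input_tensor,
                               tensor_shape):
    """Sketch: one steady-state iteration on an interior pipeline stage."""
    output_tensor = forward_step(input_tensor)
    # Ship the activation forward and, in the same batched step, collect the
    # gradient for a previously sent activation.
    output_tensor_grad = send_forward_recv_backward(output_tensor,
                                                    tensor_shape=tensor_shape)
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    # Ship the gradient backward and pick up the next microbatch's activation.
    return send_backward_recv_forward(input_tensor_grad,
                                      tensor_shape=tensor_shape)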


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor
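

# Illustrative sketch (assumption): during warmup in the pipelined schedules,
# each stage forwards its output while simultaneously receiving the next
# microbatch's input; recv_prev is dropped only when no further input is
# expected. is_last_step is a hypothetical flag supplied by the caller.
def _example_warmup_step(output_tensor, is_last_step, tensor_shape):
    """Sketch: one warmup communication step on an interior stage."""
    input_tensor = send_forward_recv_forward(output_tensor,
                                             recv_prev=not is_last_step,
                                             tensor_shape=tensor_shape)
    return input_tensor  # None when recv_prev was False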


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape=None, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv',
               log_level=2).start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
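

# Illustrative sketch (assumption): the fully fused helper above serves the
# interleaved schedule, where a stage runs forward and backward for different
# model chunks in the same step and all four transfers share one batched call.
def _example_fused_step(output_tensor, input_tensor_grad, tensor_shape):
    """Sketch: exchange activations and gradients with both neighbors at once."""
    input_tensor, output_tensor_grad = \
        send_forward_backward_recv_forward_backward(
            output_tensor, input_tensor_grad,
            recv_prev=True, recv_next=True,
            tensor_shape=tensor_shape)
    return input_tensor, output_tensor_grad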