# p2p_communication.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator

import torch

from megatron import get_args
from megatron import mpu
from megatron.mpu.random import make_viewless_tensor


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 use_ring_exchange=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape). If None, defaults to
                      (args.seq_length, args.micro_batch_size,
                      args.hidden_size).
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used instead of batched
                           isend/irecv.
        dtype_: optional, this is used when the tensor that needs to be
                communicated is different from args.params_dtype. When
                given, received tensors are created with
                requires_grad=False; otherwise they require grad.

    Returns:
        (tensor_recv_prev, tensor_recv_next). An entry is None when the
        corresponding recv_prev / recv_next flag is False.
    """
    args = get_args()

    # Some legacy inference code doesn't set the tensor shape, do so now
    # for the normal values for gpt/bert. This could be removed if inference
    # code is changed to provide tensor_shape.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # With the scatter-gather optimization each rank exchanges a 1-D chunk of
    # numel / tensor-model-parallel-world-size elements. It is only used when
    # the flattened tensor divides evenly; otherwise full tensors are sent.
    use_scatter_gather = False
    tensor_chunk_shape = tensor_shape
    if args.scatter_gather_tensors_in_pipeline:
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % tp_world_size == 0:
            tensor_chunk_shape = numel // tp_world_size
            use_scatter_gather = True

    # Receive buffers use the params dtype (fp32 when residual connections
    # are kept in fp32) unless the caller overrides it with dtype_.
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if use_scatter_gather:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        # NOTE(review): ring_exchange is not part of upstream
        # torch.distributed; presumably this requires a PyTorch build that
        # provides it -- confirm before enabling.
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        # Ops are queued in a fixed order (send prev, recv prev, send next,
        # recv next) and issued together as one batch.
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if ops:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather the received chunks back
    # into full tensors of tensor_shape. Use the same requires_grad as the
    # receive buffers so a dtype_ override is honoured on this path too
    # (an unconditional requires_grad_() also fails for non-floating dtypes).
    if use_scatter_gather:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_(requires_grad)
            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
                                                    requires_grad=requires_grad,
                                                    keep_graph=False)

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_(requires_grad)
            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
                                                    requires_grad=requires_grad,
                                                    keep_graph=False)

    return tensor_recv_prev, tensor_recv_next


def recv_forward(tensor_shape=None, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive).

    Args:
        tensor_shape: shape of the tensor to receive, forwarded to
            _communicate (None selects its default shape).
        dtype_: optional dtype override for the receive buffer, forwarded to
            _communicate.
        timers: optional timers callable; when given, the 'forward-recv'
            timer is started before and stopped after the communication.

    Returns:
        The received input tensor, or None on the first pipeline stage
        (which has no previous rank to receive from).
    """
    if mpu.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype_)
    if timers is not None:
        timers('forward-recv').stop()
    return input_tensor


def recv_backward(tensor_shape=None, timers=None):
    """Receive tensor from next rank in pipeline (backward receive).

    Args:
        tensor_shape: shape of the tensor to receive, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the 'backward-recv'
            timer is started before and stopped after the communication.

    Returns:
        The received output-tensor gradient, or None on the last pipeline
        stage (which has no next rank to receive from).
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send).

    Does nothing on the last pipeline stage.

    Args:
        output_tensor: tensor to send to the next rank.
        tensor_shape: forwarded to _communicate. It still matters for a pure
            send because it decides whether the scatter-gather split is used.
        dtype_: forwarded to _communicate, where it only affects the dtype of
            receive buffers; this call receives nothing.
        timers: optional timers callable; when given, the 'forward-send'
            timer is started before and stopped after the communication.
    """
    if mpu.is_pipeline_last_stage():
        return

    if timers is not None:
        timers('forward-send').start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype_)
    if timers is not None:
        timers('forward-send').stop()


def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
    """Send tensor to previous rank in pipeline (backward send).

    Does nothing on the first pipeline stage.

    Args:
        input_tensor_grad: gradient tensor to send to the previous rank.
        tensor_shape: forwarded to _communicate. It still matters for a pure
            send because it decides whether the scatter-gather split is used.
        timers: optional timers callable; when given, the 'backward-send'
            timer is started before and stopped after the communication.
    """
    if mpu.is_pipeline_first_stage():
        return

    if timers is not None:
        timers('backward-send').start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
    """Batched send and recv with next rank in pipeline.

    Sends output_tensor to the next rank and receives the output-tensor
    gradient from it in a single _communicate call.

    Args:
        output_tensor: tensor to send to the next rank.
        tensor_shape: shape of the tensors exchanged, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the
            'forward-send-backward-recv' timer brackets the communication.

    Returns:
        The received gradient, or None on the last pipeline stage, where
        nothing is sent or received.
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('forward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
    """Batched send and recv with previous rank in pipeline.

    Sends input_tensor_grad to the previous rank and receives the next input
    tensor from it in a single _communicate call.

    Args:
        input_tensor_grad: gradient tensor to send to the previous rank.
        tensor_shape: shape of the tensors exchanged, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the
            'backward-send-forward-recv' timer brackets the communication.

    Returns:
        The received input tensor, or None on the first pipeline stage,
        where nothing is sent or received.
    """
    if mpu.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('backward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline.

    No pipeline-stage check is done here: the caller controls what is
    exchanged (an output_tensor of None sends nothing; recv_prev=False
    receives nothing).

    Args:
        output_tensor: tensor to send to the next rank, or None.
        recv_prev: whether to receive a tensor from the previous rank.
        tensor_shape: shape of the tensors exchanged, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the
            'forward-send-forward-recv' timer brackets the communication.

    Returns:
        The tensor received from the previous rank, or None when recv_prev
        is False.
    """
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline.

    No pipeline-stage check is done here: the caller controls what is
    exchanged (an input_tensor_grad of None sends nothing; recv_next=False
    receives nothing).

    Args:
        input_tensor_grad: gradient tensor to send to the previous rank, or
            None.
        recv_next: whether to receive a gradient from the next rank.
        tensor_shape: shape of the tensors exchanged, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the
            'backward-send-backward-recv' timer brackets the communication.

    Returns:
        The gradient received from the next rank, or None when recv_next is
        False.
    """
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape=None, timers=None):
    """Batched send and recv with previous and next ranks in pipeline.

    All four transfers happen in one _communicate call. No pipeline-stage
    check is done here: a None tensor is not sent, and a False recv flag
    means nothing is received from that side.

    Args:
        output_tensor: tensor to send to the next rank, or None.
        input_tensor_grad: gradient tensor to send to the previous rank, or
            None.
        recv_prev: whether to receive a tensor from the previous rank.
        recv_next: whether to receive a gradient from the next rank.
        tensor_shape: shape of the tensors exchanged, forwarded to
            _communicate (None selects its default shape).
        timers: optional timers callable; when given, the
            'forward-backward-send-forward-backward-recv' timer brackets
            the communication.

    Returns:
        (input_tensor, output_tensor_grad); each is None when the matching
        recv flag is False.
    """
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad