# p2p_communication.py: point-to-point tensor communication helpers for pipeline stages.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
17
from functools import reduce
import operator
18
19
20
21
22
23
import torch

from megatron import get_args
from megatron import mpu


24
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 use_ring_exchange=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape). If None, falls back to
                      (args.seq_length, args.micro_batch_size,
                      args.hidden_size).
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
        dtype_: optional, this is used when the tensor that needs to be
                communicated is different from args.params_dtype. When given,
                the received tensors do not require grad.
    Returns:
        (tensor_recv_prev, tensor_recv_next); each entry is None when the
        corresponding recv_* flag is false.
    """
    args = get_args()

    # Some legacy inference code doesn't set the tensor shape, do so now
    # for the normal values for gpt/bert. This could be removed if inference
    # code is changed to provide tensor_shape.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # The scatter-gather optimization exchanges a flat 1-D chunk holding
    # numel / tensor-model-parallel-world-size elements. It is only used when
    # the element count divides evenly; otherwise full tensors are exchanged.
    use_scatter_gather = False
    recv_buffer_shape = tensor_shape
    if args.scatter_gather_tensors_in_pipeline:
        tensor_numel = reduce(operator.mul, tensor_shape, 1)
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        if tensor_numel % tp_world_size == 0:
            recv_buffer_shape = tensor_numel // tp_world_size
            use_scatter_gather = True

    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    # An explicit dtype_ overrides the dtype above and disables grad tracking
    # on the received tensors.
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    def _empty_recv_buffer():
        # Placeholder tensor that the receive operation fills in place.
        return torch.empty(recv_buffer_shape,
                           requires_grad=requires_grad,
                           device=torch.cuda.current_device(),
                           dtype=dtype)

    def _gather_received(chunk):
        # Reassemble the full tensor from the received 1-D chunk. The
        # requires_grad setting follows the receive buffer instead of being
        # forced to True, so a non-floating dtype_ does not make
        # requires_grad_() raise.
        full = mpu.gather_split_1d_tensor(chunk).view(tensor_shape)
        full.requires_grad_(requires_grad)
        return mpu.make_viewless_tensor(full,
                                        requires_grad=requires_grad,
                                        keep_graph=False)

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = _empty_recv_buffer() if recv_prev else None
    tensor_recv_next = _empty_recv_buffer() if recv_next else None

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if use_scatter_gather:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if ops:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if use_scatter_gather:
        if recv_prev:
            tensor_recv_prev = _gather_received(tensor_recv_prev)

        if recv_next:
            tensor_recv_next = _gather_received(tensor_recv_next)

    return tensor_recv_prev, tensor_recv_next


172
def recv_forward(tensor_shape=None, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive).

    Arguments:
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        dtype_: optional dtype override; passed through to _communicate().
        timers: optional callable; when given, timers('forward-recv') is
                started before and stopped after the communication.

    Returns:
        The tensor received from the previous rank, or None on the first
        pipeline stage (no communication is performed there).
    """
    if mpu.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype_)
    if timers is not None:
        timers('forward-recv').stop()
    return input_tensor


192
def recv_backward(tensor_shape=None, timers=None):
    """Receive tensor from next rank in pipeline (backward receive).

    Arguments:
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        timers: optional callable; when given, timers('backward-recv') is
                started before and stopped after the communication.

    Returns:
        The tensor received from the next rank, or None on the last
        pipeline stage (no communication is performed there).
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-recv').stop()
    return output_tensor_grad


210
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send).

    Arguments:
        output_tensor: tensor to send to the next rank.
        tensor_shape: passed through to _communicate().
        dtype_: optional dtype override; passed through to _communicate().
        timers: optional callable; when given, timers('forward-send') is
                started before and stopped after the communication.

    Does nothing on the last pipeline stage. Returns None.
    """
    if mpu.is_pipeline_last_stage():
        return

    if timers is not None:
        timers('forward-send').start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype_)
    if timers is not None:
        timers('forward-send').stop()


227
def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
    """Send tensor to previous rank in pipeline (backward send).

    Arguments:
        input_tensor_grad: tensor to send to the previous rank.
        tensor_shape: passed through to _communicate().
        timers: optional callable; when given, timers('backward-send') is
                started before and stopped after the communication.

    Does nothing on the first pipeline stage. Returns None.
    """
    if mpu.is_pipeline_first_stage():
        return

    if timers is not None:
        timers('backward-send').start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send').stop()


242
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
    """Batched send and recv with next rank in pipeline.

    Arguments:
        output_tensor: tensor to send to the next rank.
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        timers: optional callable; when given,
                timers('forward-send-backward-recv') is started before and
                stopped after the communication.

    Returns:
        The tensor received from the next rank, or None on the last
        pipeline stage (no communication is performed there).
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('forward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-backward-recv').stop()
    return output_tensor_grad


260
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
    """Batched send and recv with previous rank in pipeline.

    Arguments:
        input_tensor_grad: tensor to send to the previous rank.
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        timers: optional callable; when given,
                timers('backward-send-forward-recv') is started before and
                stopped after the communication.

    Returns:
        The tensor received from the previous rank, or None on the first
        pipeline stage (no communication is performed there).
    """
    if mpu.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('backward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-forward-recv').stop()
    return input_tensor


278
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline.

    Arguments:
        output_tensor: tensor to send to the next rank (passed to
                       _communicate() as tensor_send_next).
        recv_prev: boolean for whether a tensor should be received from the
                   previous rank.
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        timers: optional callable; when given,
                timers('forward-send-forward-recv') is started before and
                stopped after the communication.

    Returns:
        The first element of _communicate()'s result (the tensor received
        from the previous rank).
    """
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


293
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline.

    Arguments:
        input_tensor_grad: tensor to send to the previous rank (passed to
                           _communicate() as tensor_send_prev).
        recv_next: boolean for whether a tensor should be received from the
                   next rank.
        tensor_shape: shape of the tensor to receive; passed through to
                      _communicate().
        timers: optional callable; when given,
                timers('backward-send-backward-recv') is started before and
                stopped after the communication.

    Returns:
        The second element of _communicate()'s result (the tensor received
        from the next rank).
    """
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape=None, timers=None):
    """Batched send and recv with previous and next ranks in pipeline.

    Arguments:
        output_tensor: tensor to send to the next rank.
        input_tensor_grad: tensor to send to the previous rank.
        recv_prev: boolean for whether a tensor should be received from the
                   previous rank.
        recv_next: boolean for whether a tensor should be received from the
                   next rank.
        tensor_shape: shape of the tensors to receive; passed through to
                      _communicate().
        timers: optional callable; when given,
                timers('forward-backward-send-forward-backward-recv') is
                started before and stopped after the communication.

    Returns:
        (input_tensor, output_tensor_grad): the pair returned by
        _communicate(), i.e. the tensors received from the previous and
        next ranks respectively.
    """
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad