# p2p_communication.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
18
19
20
21
22
23
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 use_ring_exchange=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape).
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
        dtype_: optional, this is used when the tensor that needs to be
                communicated is different from args.params_dtype. When set,
                receive buffers are allocated with this dtype and the
                returned tensors have requires_grad=False, both with and
                without the scatter-gather optimization.

    Returns:
        (tensor_recv_prev, tensor_recv_next); an entry is None when the
        corresponding recv_* flag is False.
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    # With scatter-gather enabled, receive buffers are 1-D and hold
    # numel(tensor_shape) / tensor-model-parallel-world-size elements, and
    # outgoing tensors are split with split_tensor_into_1d_equal_chunks.
    # If the element count does not divide evenly, fall back to sending and
    # receiving full-shape tensors for this call.
    override_scatter_gather_tensors_in_pipeline = False
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        if tensor_chunk_shape % tp_world_size == 0:
            tensor_chunk_shape = tensor_chunk_shape // tp_world_size
        else:
            tensor_chunk_shape = tensor_shape
            override_scatter_gather_tensors_in_pipeline = True
    else:
        tensor_chunk_shape = tensor_shape
    use_scatter_gather = (args.scatter_gather_tensors_in_pipeline and
                          not override_scatter_gather_tensors_in_pipeline)

    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    # An explicit dtype_ overrides the default dtype and marks the received
    # tensors as not requiring grad.
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if use_scatter_gather:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks back into
    # full-shape tensors. Apply the same requires_grad flag used for the
    # receive buffers: an unconditional requires_grad_() would make dtype_
    # callers get grad-requiring tensors only on this path, and it raises
    # for non-floating-point dtypes.
    if use_scatter_gather:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_(requires_grad)

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_(requires_grad)

    return tensor_recv_prev, tensor_recv_next


def recv_forward(tensor_shape, dtype_=None, timers=None):
    """Receive a tensor from the previous rank in the pipeline (forward receive).

    Args:
        tensor_shape: shape of the tensor to receive.
        dtype_: optional dtype override passed through to _communicate.
        timers: optional timers callable; when given, the 'forward-recv'
            timer is started before and stopped after the transfer.

    Returns:
        The received tensor, or None on the first pipeline stage, which has
        no previous rank to receive from.
    """
    if mpu.is_pipeline_first_stage():
        return None

    if timers is not None:
        timers('forward-recv').start()
    from_prev, _unused_next = _communicate(
        tensor_send_prev=None,
        tensor_send_next=None,
        recv_next=False,
        recv_prev=True,
        dtype_=dtype_,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-recv').stop()
    return from_prev


def recv_backward(tensor_shape, timers=None):
    """Receive a tensor from the next rank in the pipeline (backward receive).

    Args:
        tensor_shape: shape of the tensor to receive.
        timers: optional timers callable; when given, the 'backward-recv'
            timer is started before and stopped after the transfer.

    Returns:
        The received tensor, or None on the last pipeline stage, which has
        no next rank to receive from.
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('backward-recv').start()
    _unused_prev, grad_from_next = _communicate(
        tensor_send_prev=None,
        tensor_send_next=None,
        recv_next=True,
        recv_prev=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-recv').stop()
    return grad_from_next


def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
    """Send output_tensor to the next rank in the pipeline (forward send).

    Does nothing on the last pipeline stage. dtype_ is passed through to
    _communicate, which only uses it when allocating receive buffers; this
    call receives nothing, so dtype_ does not affect the transfer.

    Args:
        output_tensor: tensor sent to the next rank.
        tensor_shape: shape argument forwarded to _communicate.
        dtype_: optional dtype override forwarded to _communicate.
        timers: optional timers callable; when given, the 'forward-send'
            timer brackets the transfer.
    """
    if mpu.is_pipeline_last_stage():
        return

    if timers is not None:
        timers('forward-send').start()
    _communicate(
        tensor_send_prev=None,
        tensor_send_next=output_tensor,
        recv_next=False,
        recv_prev=False,
        dtype_=dtype_,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send').stop()


def send_backward(input_tensor_grad, tensor_shape, timers=None):
    """Send input_tensor_grad to the previous rank in the pipeline (backward send).

    Does nothing on the first pipeline stage.

    Args:
        input_tensor_grad: tensor sent to the previous rank.
        tensor_shape: shape argument forwarded to _communicate.
        timers: optional timers callable; when given, the 'backward-send'
            timer brackets the transfer.
    """
    if mpu.is_pipeline_first_stage():
        return

    if timers is not None:
        timers('backward-send').start()
    _communicate(
        tensor_send_prev=input_tensor_grad,
        tensor_send_next=None,
        recv_next=False,
        recv_prev=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
    """Batched send and recv with next rank in pipeline.

    Sends output_tensor to the next rank and receives a tensor from that
    same rank in a single _communicate call.

    Args:
        output_tensor: tensor sent to the next rank.
        tensor_shape: shape of the tensor to receive.
        timers: optional timers callable; when given, the
            'forward-send-backward-recv' timer brackets the transfer.

    Returns:
        The tensor received from the next rank, or None on the last pipeline
        stage (where nothing is sent or received).
    """
    if mpu.is_pipeline_last_stage():
        return None

    if timers is not None:
        timers('forward-send-backward-recv').start()
    _unused_prev, grad_from_next = _communicate(
        tensor_send_prev=None,
        tensor_send_next=output_tensor,
        recv_next=True,
        recv_prev=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-backward-recv').stop()
    return grad_from_next


def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None,
                               dtype_=None):
    """Batched send and recv with previous rank in pipeline.

    Sends input_tensor_grad to the previous rank and receives a tensor from
    that same rank in a single _communicate call.

    Args:
        input_tensor_grad: tensor sent to the previous rank.
        tensor_shape: shape of the tensor to receive.
        timers: optional timers callable; when given, the
            'backward-send-forward-recv' timer brackets the transfer.
        dtype_: optional dtype override for the receive buffer, forwarded to
            _communicate the same way recv_forward forwards it (None keeps
            the default dtype handling). It is placed after timers so
            existing positional callers are unaffected.

    Returns:
        The tensor received from the previous rank, or None on the first
        pipeline stage (where nothing is sent or received).
    """
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None,
                              dtype_=None):
    """Batched recv from previous rank and send to next rank in pipeline.

    Unlike the stage-guarded helpers, this always calls _communicate; the
    caller controls via recv_prev whether anything is received.

    Args:
        output_tensor: tensor sent to the next rank (None sends nothing).
        recv_prev: whether to receive a tensor from the previous rank.
        tensor_shape: shape of the tensor to receive.
        timers: optional timers callable; when given, the
            'forward-send-forward-recv' timer brackets the transfer.
        dtype_: optional dtype override for the receive buffer, forwarded to
            _communicate the same way recv_forward forwards it (None keeps
            the default dtype handling). It is placed after timers so
            existing positional callers are unaffected.

    Returns:
        The tensor received from the previous rank, or None when recv_prev
        is falsy.
    """
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype_)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline.

    There is no pipeline-stage check here; _communicate is always called and
    recv_next alone decides whether a tensor is received.

    Args:
        input_tensor_grad: tensor sent to the previous rank (None sends
            nothing).
        recv_next: whether to receive a tensor from the next rank.
        tensor_shape: shape of the tensor to receive.
        timers: optional timers callable; when given, the
            'backward-send-backward-recv' timer brackets the transfer.

    Returns:
        The tensor received from the next rank, or None when recv_next is
        falsy.
    """
    timed = timers is not None
    if timed:
        timers('backward-send-backward-recv').start()
    _unused_prev, grad_from_next = _communicate(
        tensor_send_prev=input_tensor_grad,
        tensor_send_next=None,
        recv_next=recv_next,
        recv_prev=False,
        tensor_shape=tensor_shape)
    if timed:
        timers('backward-send-backward-recv').stop()
    return grad_from_next


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape, timers=None):
    """Batched send and recv with previous and next ranks in pipeline.

    Performs up to four transfers in one _communicate call: output_tensor to
    the next rank, input_tensor_grad to the previous rank, and optional
    receives from each neighbour as selected by recv_prev / recv_next.

    Args:
        output_tensor: tensor sent to the next rank (None sends nothing).
        input_tensor_grad: tensor sent to the previous rank (None sends
            nothing).
        recv_prev: whether to receive a tensor from the previous rank.
        recv_next: whether to receive a tensor from the next rank.
        tensor_shape: shape of the tensors to receive.
        timers: optional timers callable; when given, the
            'forward-backward-send-forward-backward-recv' timer brackets the
            transfer.

    Returns:
        A (from_prev, from_next) pair; each element is None when the
        matching recv_* flag is falsy.
    """
    timed = timers is not None
    if timed:
        timers('forward-backward-send-forward-backward-recv').start()
    from_prev, from_next = _communicate(
        tensor_send_prev=input_tensor_grad,
        tensor_send_next=output_tensor,
        recv_next=recv_next,
        recv_prev=recv_prev,
        tensor_shape=tensor_shape)
    if timed:
        timers('forward-backward-send-forward-backward-recv').stop()
    return from_prev, from_next