p2p_communication.py 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
17
from functools import reduce
import operator
18
19
20
21
22
23
24
25
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used. That API is not part of stock
                           PyTorch builds; when it is missing, the batched
                           isend/irecv path is used instead.

    Returns:
        (tensor_recv_prev, tensor_recv_next): each is a tensor of shape
        (seq_length, micro_batch_size, hidden_size), or None for a
        direction that was not requested.
    """
    args = get_args()

    # Full shape of an activation / gradient exchanged between stages.
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        # Each tensor-model-parallel rank exchanges only a flat 1-D slice;
        # the full tensor is re-assembled by the gather further below.
        tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) //
                              mpu.get_tensor_model_parallel_world_size())
    else:
        tensor_chunk_shape = tensor_shape

    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    # ring_exchange() only exists in custom PyTorch builds; without it, fall
    # back to the equivalent batched isend/irecv path instead of raising
    # AttributeError.
    if use_ring_exchange and hasattr(torch.distributed, 'ring_exchange'):
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        # batch_isend_irecv() does not accept an empty op list, so only call
        # it when there is actually something to send or receive.
        if ops:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None, use_ring_exchange=False):
    """Receive the forward activation from the previous pipeline stage.

    Returns None on the first pipeline stage, which has no previous rank;
    otherwise returns the tensor received from the previous rank. When
    `timers` is given, the 'forward-recv' timer brackets the receive.
    """
    # The first stage has no predecessor, so there is nothing to receive.
    if mpu.is_pipeline_first_stage():
        return None

    timer_name = 'forward-recv'
    if timers is not None:
        timers(timer_name).start()
    received, _unused = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()
    return received


def recv_backward(timers=None, use_ring_exchange=False):
    """Receive the output gradient from the next pipeline stage.

    Returns None on the last pipeline stage, which has no next rank;
    otherwise returns the gradient tensor received from the next rank.
    When `timers` is given, the 'backward-recv' timer brackets the receive.
    """
    # The last stage has no successor, so no gradient arrives from downstream.
    if mpu.is_pipeline_last_stage():
        return None

    timer_name = 'backward-recv'
    if timers is not None:
        timers(timer_name).start()
    _unused, received_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()
    return received_grad


def send_forward(output_tensor, timers=None, use_ring_exchange=False):
    """Send the forward activation to the next pipeline stage.

    Does nothing on the last pipeline stage, which has no next rank. When
    `timers` is given, the 'forward-send' timer brackets the send.
    """
    # Nothing downstream of the last stage to send to.
    if mpu.is_pipeline_last_stage():
        return

    timer_name = 'forward-send'
    if timers is not None:
        timers(timer_name).start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()


def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Send the input gradient to the previous pipeline stage.

    Does nothing on the first pipeline stage, which has no previous rank.
    When `timers` is given, the 'backward-send' timer brackets the send.
    """
    # Nothing upstream of the first stage to send the gradient to.
    if mpu.is_pipeline_first_stage():
        return

    timer_name = 'backward-send'
    if timers is not None:
        timers(timer_name).start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()


def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
    """Batched exchange with the next rank: send activation, receive gradient.

    Returns None on the last pipeline stage (no next rank, nothing sent);
    otherwise returns the gradient tensor received from the next rank.
    When `timers` is given, the 'forward-send-backward-recv' timer brackets
    the exchange.
    """
    if mpu.is_pipeline_last_stage():
        return None

    timer_name = 'forward-send-backward-recv'
    if timers is not None:
        timers(timer_name).start()
    _unused, received_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()
    return received_grad


def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Batched exchange with the previous rank: send gradient, receive activation.

    Returns None on the first pipeline stage (no previous rank, nothing
    sent); otherwise returns the activation tensor received from the
    previous rank. When `timers` is given, the 'backward-send-forward-recv'
    timer brackets the exchange.
    """
    if mpu.is_pipeline_first_stage():
        return None

    timer_name = 'backward-send-forward-recv'
    if timers is not None:
        timers(timer_name).start()
    received, _unused = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers(timer_name).stop()
    return received


def send_forward_recv_forward(output_tensor, recv_prev, timers=None,
                              use_ring_exchange=True):
    """Batched recv from previous rank and send to next rank in pipeline.

    Args:
        output_tensor: tensor sent to the next rank (nothing sent if None).
        recv_prev: whether to receive a tensor from the previous rank.
        timers: optional timers callable; when given, the
            'forward-send-forward-recv' timer brackets the exchange.
        use_ring_exchange: defaults to True, the value previously hard-coded
            here. torch.distributed.ring_exchange() is not provided by stock
            PyTorch builds, so callers may pass False to use the batched
            isend/irecv path like the other helpers in this module.

    Returns:
        The tensor received from the previous rank, or None when
        `recv_prev` is False.
    """
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None,
                                use_ring_exchange=True):
    """Batched recv from next rank and send to previous rank in pipeline.

    Args:
        input_tensor_grad: gradient sent to the previous rank (nothing sent
            if None).
        recv_next: whether to receive a gradient from the next rank.
        timers: optional timers callable; when given, the
            'backward-send-backward-recv' timer brackets the exchange.
        use_ring_exchange: defaults to True, the value previously hard-coded
            here. torch.distributed.ring_exchange() is not provided by stock
            PyTorch builds, so callers may pass False to use the batched
            isend/irecv path like the other helpers in this module.

    Returns:
        The gradient received from the next rank, or None when `recv_next`
        is False.
    """
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None, use_ring_exchange=True):
    """Batched send and recv with previous and next ranks in pipeline.

    Args:
        output_tensor: tensor sent to the next rank (nothing sent if None).
        input_tensor_grad: gradient sent to the previous rank (nothing sent
            if None).
        recv_prev: whether to receive a tensor from the previous rank.
        recv_next: whether to receive a gradient from the next rank.
        timers: optional timers callable; when given, the
            'forward-backward-send-forward-backward-recv' timer brackets the
            exchange.
        use_ring_exchange: defaults to True, the value previously hard-coded
            here. torch.distributed.ring_exchange() is not provided by stock
            PyTorch builds, so callers may pass False to use the batched
            isend/irecv path like the other helpers in this module.

    Returns:
        (input_tensor, output_tensor_grad): what was received from the
        previous and next ranks respectively; each is None when the
        corresponding recv flag is False.
    """
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        use_ring_exchange=use_ring_exchange)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad