p2p_communication.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_prev,
            mpu.get_pipeline_model_parallel_prev_rank())
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_prev,
            mpu.get_pipeline_model_parallel_prev_rank())
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_next,
            mpu.get_pipeline_model_parallel_next_rank())
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_next,
            mpu.get_pipeline_model_parallel_next_rank())
        ops.append(recv_next_op)
    if ops:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    # To protect against a race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next
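

# Hedged, illustrative sketch (not part of the original module): the batched
# point-to-point pattern used by _communicate() above, reduced to a standalone
# ring exchange. Each rank posts an isend to its next neighbor and an irecv
# from its previous neighbor, then waits on every request. Assumes
# torch.distributed is already initialized (e.g. with the NCCL backend) and
# that each rank owns one GPU; _example_ring_exchange is a hypothetical
# helper, not part of Megatron's API.
def _example_ring_exchange(tensor_shape, dtype=torch.float):
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    # Buffer to send to the next rank and buffer to receive from the previous.
    send_buf = torch.full(tensor_shape, float(rank),
                          device=torch.cuda.current_device(), dtype=dtype)
    recv_buf = torch.empty(tensor_shape,
                           device=torch.cuda.current_device(), dtype=dtype)
    ops = [
        torch.distributed.P2POp(torch.distributed.isend, send_buf,
                                (rank + 1) % world_size),
        torch.distributed.P2POp(torch.distributed.irecv, recv_buf,
                                (rank - 1) % world_size),
    ]
    for req in torch.distributed.batch_isend_irecv(ops):
        req.wait()
    # As in _communicate(), synchronize to guard against the race condition
    # batch_isend_irecv() can otherwise expose.
    torch.cuda.synchronize()
    return recv_buf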


def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor
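

# Hedged, illustrative sketch (not part of the original module): in the
# steady-state phase of a 1F1B pipeline schedule, the two combined helpers
# above replace pairs of separate send/recv calls, so the send to a neighbor
# and the receive from that same neighbor go out in a single batched
# exchange. forward_step and backward_step are hypothetical placeholders for
# the schedule's step functions; the sketch is simplified to one in-flight
# microbatch, whereas the real schedule in megatron/schedules.py keeps queues
# of input and output tensors.
def _example_steady_state_step(input_tensor, forward_step, backward_step,
                               timers=None):
    output_tensor = forward_step(input_tensor)
    # Send the activation downstream and, in the same batched call, receive a
    # gradient for a previously sent activation from downstream.
    output_tensor_grad = send_forward_recv_backward(output_tensor,
                                                    timers=timers)
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    # Send the gradient upstream and receive the next microbatch's activation
    # from upstream, again in one batched call.
    next_input_tensor = send_backward_recv_forward(input_tensor_grad,
                                                   timers=timers)
    return next_input_tensor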


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
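

# Hedged, illustrative sketch (not part of the original module): how the
# simple helpers above combine around one microbatch's forward and backward
# pass, roughly as a schedule in megatron/schedules.py would drive them.
# forward_step and backward_step are hypothetical placeholders. On the first
# and last pipeline stages the corresponding recv/send calls return None or
# become no-ops, as handled inside the helpers themselves.
def _example_one_microbatch(forward_step, backward_step, timers=None):
    # Activation from the previous stage (None on the first stage).
    input_tensor = recv_forward(timers=timers)
    output_tensor = forward_step(input_tensor)
    # Activation to the next stage (no-op on the last stage).
    send_forward(output_tensor, timers=timers)

    # Gradient from the next stage (None on the last stage).
    output_tensor_grad = recv_backward(timers=timers)
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    # Gradient to the previous stage (no-op on the first stage).
    send_backward(input_tensor_grad, timers=timers)
    return input_tensor_grad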