# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
from contextlib import contextmanager
17
import torch
18
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
19
20

from megatron import get_args
21
from megatron import get_num_microbatches
22
23
from megatron import get_timers
from megatron import mpu
24
from megatron import p2p_communication
25
26


def get_forward_backward_func():
    """Pick the forward/backward schedule for the current parallel layout.

    Returns:
        forward_backward_no_pipelining when the pipeline-model-parallel world
        size is at most 1; otherwise forward_backward_pipelining_with_interleaving
        if args.virtual_pipeline_model_parallel_size is set, and
        forward_backward_pipelining_without_interleaving if it is None.
    """
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    if args.virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving
    return forward_backward_pipelining_with_interleaving


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Args:
        forward_step_func: callable ``(data_iterator, model)`` returning
            ``(output_tensor, loss_func)``.
        data_iterator: passed through unchanged to forward_step_func.
        model: the (possibly wrapped) model to run.
        input_tensor: tensor handed to the model's ``set_input_tensor``
            (may be None).
        losses_reduced: list; on the last pipeline stage the second element
            returned by ``loss_func`` is appended to it.

    Returns output tensor. On the last pipeline stage this is the loss
    returned by ``loss_func`` divided by the number of microbatches."""
    timers = get_timers()

    timers('forward-compute').start()
    # Walk through wrapper modules instead of assuming exactly two levels of
    # wrapping (the previous hard-coded ``model.module.module``).
    _unwrap_to_input_tensor_setter(model).set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        # Divide so that accumulating gradients over all microbatches yields
        # the mean loss.
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def _unwrap_to_input_tensor_setter(model):
    """Return the first object in ``model``'s ``.module`` chain that defines
    ``set_input_tensor``.

    Starts at ``model`` itself and follows ``.module`` attributes (as exposed
    by wrappers such as torch DistributedDataParallel) until an object with a
    ``set_input_tensor`` attribute is found, so any wrapping depth works.

    Raises:
        AttributeError: if the chain ends without such an object.
    """
    module = model
    while not hasattr(module, 'set_input_tensor'):
        if not hasattr(module, 'module'):
            raise AttributeError(
                '{} (and the modules it wraps) has no set_input_tensor '
                'method'.format(type(model).__name__))
        module = module.module
    return module


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        optimizer: object providing ``scale_loss(loss)``; only used when
            output_tensor_grad is None.
        input_tensor: tensor this stage consumed as input, or None.
        output_tensor: this stage's output (the loss on the last stage).
        output_tensor_grad: gradient with respect to output_tensor, or None.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    timers = get_timers()
    timers('backward-compute').start()

    # Ask autograd to populate input_tensor.grad so it can be returned below.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass. With no upstream gradient, output_tensor is the loss and
    # is passed through optimizer.scale_loss first.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None if input_tensor is None else input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    """No-op context manager.

    Used in place of ``model.no_sync`` when the model is not a torch
    DistributedDataParallel instance; it yields None and lets any exception
    raised in the ``with`` body propagate unchanged.
    """
    yield


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    ``model`` must be a one-element list. When that model is a torch
    DistributedDataParallel instance, every microbatch except the last runs
    inside ``model.no_sync()``; the last one runs outside it so gradients are
    synchronized.

    Returns the list of reduced losses gathered by forward_step."""
    assert len(model) == 1
    model = model[0]

    sync_context = model.no_sync if isinstance(model, torchDDP) else dummy_handler

    losses_reduced = []
    input_tensor = None
    output_tensor_grad = None

    def run_microbatch():
        # One forward pass, plus the matching backward pass when training.
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

    with sync_context():
        for _ in range(get_num_microbatches() - 1):
            run_microbatch()

    # The final microbatch runs outside the context handler because gradients
    # should be synchronized on its backward pass.
    run_microbatch()

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Args:
        forward_step_func: passed through to forward_step.
        data_iterator: list with one data iterator per model chunk.
        model: list of model chunks run by this pipeline rank.
        optimizer: passed through to backward_step.
        timers: passed through to p2p_communication.
        forward_only: if True, run forward passes only.

    Per-chunk FIFO lists hold the input tensors, output tensors and (when
    training) output-tensor gradients waiting for their backward pass. In
    forward-only mode these tensors are dropped as soon as they are used.

    Returns the list of reduced losses collected by forward_step (only
    appended to on the last pipeline stage)."""
    num_model_chunks = len(model)
    input_tensors = [[] for _ in range(num_model_chunks)]
    output_tensors = [[] for _ in range(num_model_chunks)]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(num_model_chunks)]

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches (counted per chunk).
    total_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = total_microbatches
    elif get_num_microbatches() == pp_size:
        # Run all forward passes and then all backward passes if the number
        # of microbatches equals the number of pipeline stages.
        num_warmup_microbatches = total_microbatches
        all_warmup_microbatches = True
    else:
        # Every rank runs (num_model_chunks - 1) * pp_size warmup passes, plus
        # two more per stage after it: earlier stages run further ahead, and
        # the last stage can start 1F1B immediately.
        num_warmup_microbatches = min(
            (pp_size - pp_rank - 1) * 2 + (num_model_chunks - 1) * pp_size,
            total_microbatches)
    num_microbatches_remaining = total_microbatches - num_warmup_microbatches

    def chunk_id_of(microbatch_id, forward):
        """Map a microbatch index to the model chunk it runs on.

        Microbatches visit the chunks in groups of pp_size; backward passes
        visit the chunks in reverse order."""
        chunk_id = (microbatch_id % (pp_size * num_model_chunks)) // pp_size
        return chunk_id if forward else num_model_chunks - chunk_id - 1

    def run_forward(microbatch_id):
        """Run forward_step for microbatch_id on its model chunk, after
        setting the virtual pipeline rank to that chunk."""
        chunk = chunk_id_of(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(chunk)

        # On the first stage nothing is received from upstream; queue a None
        # input when no pending input exists for this chunk.
        if mpu.is_pipeline_first_stage() and \
                len(input_tensors[chunk]) == len(output_tensors[chunk]):
            input_tensors[chunk].append(None)
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[chunk],
                                     model[chunk],
                                     input_tensors[chunk][-1],
                                     losses_reduced)
        output_tensors[chunk].append(output_tensor)

        # Without a backward pass nothing reads these entries again; drop
        # them rather than keeping every microbatch's tensors alive until
        # this function returns.
        if forward_only:
            input_tensors[chunk].pop()
            output_tensors[chunk].pop()

        return output_tensor

    def run_backward(microbatch_id):
        """Run backward_step for microbatch_id on its model chunk, after
        setting the virtual pipeline rank to that chunk."""
        chunk = chunk_id_of(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(chunk)

        # The last stage receives no gradient from downstream.
        if mpu.is_pipeline_last_stage() and not output_tensor_grads[chunk]:
            output_tensor_grads[chunk].append(None)
        return backward_step(optimizer,
                             input_tensors[chunk].pop(0),
                             output_tensors[chunk].pop(0),
                             output_tensor_grads[chunk].pop(0))

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = run_forward(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = chunk_id_of(k + 1, forward=True)
        recv_prev = not (mpu.is_pipeline_first_stage(ignore_virtual=True)
                         and next_forward_model_chunk_id == 0)
        if k == total_microbatches - 1:
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send tensors computed in this iteration; receive tensors for the
        # next one. The final warmup iteration of a training run that moves on
        # to steady state also starts receiving gradients.
        if k == num_warmup_microbatches - 1 and not forward_only and \
                not all_warmup_microbatches:
            recv_next = not mpu.is_pipeline_last_stage(ignore_virtual=True)
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, None,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        forward_k = k + num_warmup_microbatches
        output_tensor = run_forward(forward_k)

        backward_k = k
        input_tensor_grad = run_backward(backward_k)

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = chunk_id_of(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = chunk_id_of(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pp_size - 1).
            next_forward_model_chunk_id = chunk_id_of(
                forward_k - (pp_size - 1), forward=True)
            recv_prev = next_forward_model_chunk_id != num_model_chunks - 1
            next_forward_model_chunk_id += 1
        else:
            recv_prev = True
            next_forward_model_chunk_id = chunk_id_of(forward_k + 1,
                                                      forward=True)

        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pp_size - 1).
            next_backward_model_chunk_id = chunk_id_of(
                backward_k - (pp_size - 1), forward=False)
            recv_next = next_backward_model_chunk_id != 0
            next_backward_model_chunk_id -= 1
        else:
            recv_next = True
            next_backward_model_chunk_id = chunk_id_of(backward_k + 1,
                                                       forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == num_microbatches_remaining - 1:
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(timers))
        for k in range(num_microbatches_remaining, total_microbatches):
            input_tensor_grad = run_backward(k)
            next_backward_model_chunk_id = chunk_id_of(k + 1, forward=False)
            recv_next = not (mpu.is_pipeline_last_stage(ignore_virtual=True)
                             and next_backward_model_chunk_id == num_model_chunks - 1)
            if k == total_microbatches - 1:
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Args:
        forward_step_func: passed through to forward_step.
        data_iterator: passed through to forward_step.
        model: one-element list holding this stage's model.
        optimizer: passed through to backward_step.
        timers: passed to p2p_communication; when None the global timers
            from get_timers() are used.
        forward_only: if True, run forward passes only. Activations are then
            not queued for a backward pass.

    Returns the list of reduced losses collected by forward_step (only
    appended to on the last pipeline stage)."""
    # Use the caller's timers like the interleaved schedule does; fall back to
    # the global timers when None is passed.
    if timers is None:
        timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches: one per downstream stage,
    # capped at the total number of microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = min(
        mpu.get_pipeline_model_parallel_world_size()
        - mpu.get_pipeline_model_parallel_rank() - 1,
        num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # FIFO queues of tensors awaiting their backward pass (training only).
    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        # Only training revisits these tensors (in the cooldown phase).
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)

        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
            continue

        output_tensor_grad = \
            p2p_communication.send_forward_recv_backward(output_tensor,
                                                         timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

        input_tensor_grad = \
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

        if last_iteration:
            input_tensor = None
            p2p_communication.send_backward(input_tensor_grad, timers)
        else:
            input_tensor = \
                p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced