# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType


def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
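
# Usage sketch (illustrative): the training loop fetches the schedule once and
# then calls it every iteration; all three schedules below share the same
# signature, so the call site does not depend on the parallel configuration.
#
#     forward_backward_func = get_forward_backward_func()
#     losses_reduced = forward_backward_func(
#         forward_step_func, data_iterator, model,
#         optimizer, timers, forward_only=False)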


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    args = get_args()
    timers = get_timers()

    timers('forward-compute').start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    # If the model has both an encoder and a decoder (e.g. T5) and this stage
    # is in the decoder stack, also send the encoder hidden state downstream.
    if mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]
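
# Illustrative sketch of the contract forward_step() relies on (names below
# are hypothetical): forward_step_func(data_iterator, model) returns
# (output_tensor, loss_func), and on the last pipeline stage loss_func maps
# the output tensor to (loss, reduced_losses_dict):
#
#     def example_forward_step_func(data_iterator, model):
#         tokens, labels = next(data_iterator)
#         output_tensor = model(tokens)
#
#         def loss_func(output_tensor):
#             loss = compute_loss(output_tensor, labels)   # hypothetical helper
#             return loss, {'lm loss': loss.detach()}
#
#         return output_tensor, loss_func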


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None:
        # Last stage: scale the (already microbatch-averaged) loss by the
        # optimizer's loss scale before backpropagating.
        output_tensor[0] = optimizer.scale_loss(output_tensor[0])
    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            if x is None:
                input_tensor_grad.append(None)
            else:
                input_tensor_grad.append(x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    timers('backward-compute').stop()

    return input_tensor_grad
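
# Note on conventions used above: output_tensor_grad is None only on the last
# pipeline stage, where the loss itself is backpropagated; every other stage
# feeds the gradient received from its downstream neighbor in via grad_tensors.
# The returned input_tensor_grad is what the schedules below send upstream,
# where it becomes the previous stage's output_tensor_grad.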


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass
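
# dummy_handler mirrors torchDDP's model.no_sync() context manager so that the
# no-pipelining schedule below can wrap every microbatch except the last in
# `with context_handler():` regardless of whether the model is DDP-wrapped.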


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
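
# Worked example (illustrative): with get_num_microbatches() == 4 and a
# torchDDP-wrapped model, microbatches 0-2 run forward and backward inside
# model.no_sync(), so their gradients only accumulate locally; the final
# microbatch runs outside the context manager, triggering the gradient
# all-reduce exactly once per iteration.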


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if the number of
        # microbatches equals the number of pipeline stages.
        # Otherwise, run (num_model_chunks - 1) * pipeline_parallel_size warmup
        # microbatches on all workers, plus additional warmup microbatches that
        # depend on the stage ID (more forward passes for earlier stages; later
        # stages can start 1F1B immediately).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
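
    # Worked example (illustrative): with pipeline_parallel_size = 4,
    # num_model_chunks = 2 and get_num_microbatches() = 8, there are
    # num_microbatches = 16 slots per rank; rank 0 runs
    # (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 warmup microbatches, rank 3 runs
    # (4 - 3 - 1) * 2 + (2 - 1) * 4 = 4, and the rest proceed in 1F1B.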

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id
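
    # Worked example (illustrative): with pipeline_parallel_size = 2 and
    # num_model_chunks = 2, forward microbatches 0, 1, 2, 3, 4, 5, ... map to
    # model chunks 0, 0, 1, 1, 0, 0, ... (groups of pipeline_parallel_size per
    # chunk); in the backward direction the chunk order within each group is
    # reversed.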

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))

    return losses_reduced


def get_tensor_shapes(rank, model_type):
    # Determine right tensor sizes (based on position of rank with respect to split
    # rank) and model size.
    # Send two tensors if model is T5 and rank is in decoder stage:
    #     first tensor is decoder (pre-transpose),
    #     second tensor is encoder (post-transpose).
    # If model is T5 and rank is at the boundary:
    #     send one tensor (post-transpose from encoder).
    # Otherwise, send one tensor (pre-transpose).
    args = get_args()
    tensor_shapes = []
    if model_type == ModelType.encoder_and_decoder:
        if mpu.is_pipeline_stage_before_split(rank):
            # If next rank is after split, then need transpose for encoder_hidden_state.
            if mpu.is_pipeline_stage_before_split(rank+1):
                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
            else:
                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
        else:
            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
    else:
        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
    return tensor_shapes
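
# Illustrative example (values assumed): for an encoder_and_decoder (T5-style)
# model with seq_length=512, decoder_seq_length=128, micro_batch_size=4 and
# hidden_size=1024, ranks in the decoder stack exchange
# [(128, 4, 1024), (4, 512, 1024)] (decoder activations plus the batch-first
# encoder hidden state), while ranks inside the encoder stack exchange a
# single (512, 4, 1024) tensor, transposed to (4, 512, 1024) at the
# encoder/decoder boundary.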


def recv_forward(tensor_shapes, timers):
    input_tensors = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            input_tensors.append(None)
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                                timers=timers))
    return input_tensors


def recv_backward(tensor_shapes, timers):
    output_tensor_grads = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            output_tensor_grads.append(None)
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                       timers=timers))
    return output_tensor_grads


def send_forward(output_tensors, tensor_shapes, timers):
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)


def send_backward(input_tensor_grads, tensor_shapes, timers):
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)


def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    output_tensor_grads = []
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            output_tensor_grads.append(None)
            continue
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape, timers=timers)
        output_tensor_grads.append(output_tensor_grad)
    return output_tensor_grads


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    input_tensors = []
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            input_tensors.append(None)
            continue
        input_tensor = p2p_communication.send_backward_recv_forward(
                input_tensor_grad, tensor_shape, timers=timers)
        input_tensors.append(input_tensor)
    return input_tensors
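
# These wrappers let the non-interleaved schedule below treat the one-tensor
# and two-tensor (encoder + decoder) cases uniformly: a None entry in
# tensor_shapes is skipped, and a matching None is recorded so list positions
# stay aligned with the shapes returned by get_tensor_shapes().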


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
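
    # Worked example (illustrative): with 4 pipeline stages and
    # get_num_microbatches() = 8, ranks 0-3 run 3, 2, 1 and 0 warmup forward
    # passes respectively, then alternate one forward and one backward pass
    # (1F1B) for the remaining microbatches, and finally drain the pipeline
    # with the matching number of cooldown backward passes.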

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)
        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes,
                                           timers=timers)
        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return losses_reduced