"src/targets/vscode:/vscode.git/clone" did not exist on "3ea86a4d4602a6efbb430f74f9733d1094375732"
schedules.py 28.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
from contextlib import contextmanager
17
import torch
18
from torch.autograd.variable import Variable
19
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
20
21

from megatron import get_args
22
from megatron import get_num_microbatches
23
24
from megatron import get_timers
from megatron import mpu
25
from megatron import p2p_communication
26
27
28
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
29
30
from megatron.model import ModelType

def get_forward_backward_func():
    """Select the forward/backward schedule for the current parallel setup.

    - Pipeline world size > 1 and ``args.virtual_pipeline_model_parallel_size``
      set: ``forward_backward_pipelining_with_interleaving`` (asserts that the
      number of microbatches is divisible by
      ``args.pipeline_model_parallel_size``).
    - Pipeline world size > 1 otherwise:
      ``forward_backward_pipelining_without_interleaving``.
    - Otherwise: ``forward_backward_no_pipelining``.

    Returns:
        The selected schedule function (not called).
    """
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            # The interleaved schedule maps consecutive groups of
            # pipeline-parallel-size microbatches onto each model chunk, so
            # the microbatch count must divide evenly.
            assert get_num_microbatches() % \
                args.pipeline_model_parallel_size == 0, \
                'number of microbatches (%d) is not divisible by pipeline-' \
                'model-parallel-size (%d) when using interleaved schedule' % (
                    get_num_microbatches(),
                    args.pipeline_model_parallel_size,
                )
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func


def deallocate_output_tensor(out):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.

    Args:
        out: the tensor to shrink, or None (in which case this is a no-op).

    Raises:
        AssertionError: if ``out`` is not a ``torch.Tensor``, or if it is a
            view of another tensor (replacing a view's data would not free
            the underlying storage).
    '''
    if out is None:
        return
    assert isinstance(out, torch.Tensor), \
        "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, \
        "counter-productive to free a view of another tensor."
    # Replace the storage with a 1-element placeholder on the same device and
    # dtype; the autograd graph attached to `out` is left untouched.
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )


def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'deallocate_output_tensor' optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: single-element tensor to backpropagate from (its data may
            have been replaced by a 1-element placeholder).
        grad_output: gradient w.r.t. ``output``, or None to use ones.

    Raises:
        AssertionError: if the argument types are wrong or ``output`` does
            not have exactly one element.
    '''
    # Type checks first so a non-tensor yields a clear AssertionError rather
    # than an AttributeError from `.numel()`.
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__
    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"

    # Handle scalar output: implicit gradient of ones.
    if grad_output is None:
        grad_output = torch.ones_like(
            output,
            memory_format=torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    The input tensor(s) are handed to the unwrapped model through
    ``set_input_tensor``; ``forward_step_func(data_iterator, model)`` then
    runs the forward pass and returns ``(output_tensor, loss_func)``.

    On the last pipeline stage, ``loss_func`` is applied. Its loss is divided
    by the number of microbatches, and its reduced-loss value is appended to
    ``losses_reduced``.

    Returns:
        ``[output_tensor, input_tensor[-1]]`` on a stage after the split of
        an encoder-and-decoder model; otherwise the bare output tensor if
        ``input_tensor`` was not passed as a list, else ``[output_tensor]``.
    """
    args = get_args()
    timers = get_timers()

    timers('forward-compute').start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))

    # Normalize to a list; remember to unwrap the result for callers that
    # passed a single tensor.
    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor. When it is None, the output (the
    loss) is first passed through ``optimizer.scale_loss``.

    Returns gradient of loss with respect to input tensor (None if first
    stage). The result is a bare value if ``input_tensor`` was passed as a
    single tensor/None, else a list aligned with ``input_tensor``.
    """

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Normalize inputs to a list and retain their grads so `.grad` is
    # populated after the backward pass.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None:
        # Keep the scaled loss as element 0 of a (new) list instead of
        # indexing into the returned tensor: indexing `[0]` fails when
        # scale_loss returns a 0-dim tensor. custom_backward only needs
        # numel() == 1. Rebinding (not mutating) leaves the caller's list
        # untouched.
        output_tensor = [optimizer.scale_loss(output_tensor[0])]
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor (None entries stay None).
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    """Context manager that does nothing.

    Used as a stand-in for ``model.no_sync`` when the model is not a torch
    DistributedDataParallel module. Yields None, and any exception raised in
    the ``with`` body propagates unchanged.
    """
    yield


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    ``model`` must be a one-element list. For a torch DDP model, all
    microbatches but the last run inside ``model.no_sync()``, so gradients
    are synchronized only on the final backward pass. ``timers`` is accepted
    for interface compatibility and is not used here.

    Returns the list of reduced losses collected by ``forward_step``.
    """
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for _ in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    ``model`` and ``data_iterator`` are lists indexed by model chunk. The
    schedule runs warmup forward passes, then a steady 1F1B phase, then
    cooldown backward passes (skipped when ``forward_only``).

    Returns the list of reduced losses appended by ``forward_step``. It is
    only populated on stages where ``mpu.is_pipeline_last_stage()`` holds
    during a forward step.
    """
    # Per-chunk FIFO queues of saved activations / received gradients.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Consecutive groups of ``pipeline_parallel_size`` microbatches map to
        the same chunk; backward passes visit chunks in reverse order.
        """
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if mpu.is_pipeline_first_stage():
            # First (virtual) stage receives nothing; pad with None so the
            # input/output queues stay aligned.
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            # Last (virtual) stage backpropagates from the loss: no incoming grad.
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            # Last warmup step: also receive the first gradient needed by
            # the steady-state phase.
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        # Output has been sent; only its autograd graph is still needed.
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))

    return losses_reduced


def get_tensor_shapes(rank, model_type):
    """Return the list of tensor shapes exchanged by pipeline ``rank``.

    For an encoder-and-decoder model:
      * a rank before the split sends one tensor, in
        (seq, batch, hidden) layout if ``rank + 1`` is also before the
        split, else in (batch, seq, hidden) layout (encoder hidden state
        crossing the split);
      * a rank after the split sends two tensors: the decoder activation
        (decoder_seq, batch, hidden) and the encoder hidden state
        (batch, seq, hidden).
    Any other model type sends a single (seq, batch, hidden) tensor.
    """
    args = get_args()
    seq_major = (args.seq_length, args.micro_batch_size, args.hidden_size)
    batch_major = (args.micro_batch_size, args.seq_length, args.hidden_size)

    if model_type != ModelType.encoder_and_decoder:
        return [seq_major]

    if not mpu.is_pipeline_stage_before_split(rank):
        decoder_shape = (args.decoder_seq_length, args.micro_batch_size,
                         args.hidden_size)
        return [decoder_shape, batch_major]

    # Next rank after the split needs the transposed encoder_hidden_state.
    if mpu.is_pipeline_stage_before_split(rank + 1):
        return [seq_major]
    return [batch_major]


def recv_forward(tensor_shapes, timers):
    """Receive one activation per shape from the previous pipeline stage.

    ``None`` shapes produce ``None`` entries without any communication.
    """
    return [
        None if shape is None
        else p2p_communication.recv_forward(shape, timers=timers)
        for shape in tensor_shapes
    ]


def recv_backward(tensor_shapes, timers):
    """Receive one output gradient per shape from the next pipeline stage.

    ``None`` shapes produce ``None`` entries without any communication.
    """
    return [
        None if shape is None
        else p2p_communication.recv_backward(shape, timers=timers)
        for shape in tensor_shapes
    ]


def send_forward(output_tensors, tensor_shapes, timers):
    """Send each output tensor to the next stage, pairing it with its shape.

    A single (non-list) tensor is treated as a one-element list; entries
    whose shape is ``None`` are skipped.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, shape, timers=timers)


def send_backward(input_tensor_grads, tensor_shapes, timers):
    """Send each input gradient to the previous stage, pairing it with its shape.

    A single (non-list) gradient is treated as a one-element list; entries
    whose shape is ``None`` are skipped.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, shape, timers=timers)


def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
    """Send outputs forward and receive their gradients back, per shape.

    A single (non-list) tensor is treated as a one-element list. ``None``
    shapes yield ``None`` gradients without communication. The result has
    one entry per (tensor, shape) pair.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    return [
        None if shape is None
        else p2p_communication.send_forward_recv_backward(
            tensor, shape, timers=timers)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
    """Send gradients backward and receive the next activations, per shape.

    A single (non-list) gradient is treated as a one-element list. ``None``
    shapes yield ``None`` inputs without communication. The result has one
    entry per (gradient, shape) pair.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    return [
        None if shape is None
        else p2p_communication.send_backward_recv_forward(
            grad, shape, timers=timers)
        for grad, shape in zip(grads, tensor_shapes)
    ]


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    ``model`` must be a one-element list. Each rank runs
    ``world_size - rank - 1`` warmup forward passes (capped at the number of
    microbatches). It then alternates one forward and one backward pass in
    steady state, and finishes with the matching cooldown backward passes.
    Backward work is skipped when ``forward_only``.

    NOTE: the ``timers`` argument is overridden by ``get_timers()``.

    Returns the list of reduced losses collected by ``forward_step``.
    """
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Shapes received from the previous rank and sent by this rank.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            # Output has been sent; only its autograd graph is still needed.
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return losses_reduced