schedules.py 29 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

3
from contextlib import contextmanager
4
import torch
5
from torch.autograd.variable import Variable
6
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
7
8

from megatron import get_args
9
from megatron import get_num_microbatches
10
11
from megatron import get_timers
from megatron import mpu
12
from megatron import p2p_communication
13
14
15
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
16
17
from megatron.model import ModelType

18

Jared Casper's avatar
Jared Casper committed
19
20
21
22
23
def get_forward_backward_func():
    """Pick the schedule function that runs one step's forward/backward passes.

    * pipeline-model-parallel world size <= 1 -> forward_backward_no_pipelining
    * pipelining and args.virtual_pipeline_model_parallel_size set
      -> forward_backward_pipelining_with_interleaving (after asserting the
      microbatch count is divisible by the pipeline size)
    * pipelining otherwise -> forward_backward_pipelining_without_interleaving
    """
    args = get_args()
    pipeline_world = mpu.get_pipeline_model_parallel_world_size()
    if pipeline_world <= 1:
        return forward_backward_no_pipelining
    if args.virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving

    # The interleaved schedule needs the microbatches to divide evenly by
    # the pipeline-model-parallel size.
    assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, (
        'number of microbatches (%d) is not divisible by pipeline-'
        'model-parallel-size (%d) when using interleaved schedule' % (
            get_num_microbatches(),
            args.pipeline_model_parallel_size,
        ))
    return forward_backward_pipelining_with_interleaving


def deallocate_output_tensor(out):
    '''Replace ``out.data`` with a one-element placeholder of the same dtype/device.

    Call this right after the output tensor has been sent to the next
    pipeline stage. From then on only its '.grad_fn' matters, not '.data'.
    A ``None`` argument is ignored. Views of other tensors are rejected,
    because shrinking a view would not release the underlying storage.
    '''
    if out is not None:
        assert isinstance(out, torch.Tensor), \
            "expected Tensor, found %s." % type(out).__name__
        assert out._base is None, \
            "counter-productive to free a view of another tensor."
        placeholder = torch.empty((1,), dtype=out.dtype, device=out.device)
        out.data = placeholder


def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'deallocate_output_tensor' optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: single-element tensor, either a pseudo-freed stage output or
            a scalar loss.
        grad_output: gradient w.r.t. ``output``, or None to use a tensor of
            ones shaped like ``output``.

    Raises:
        AssertionError: if ``output`` is not a tensor, ``grad_output`` is
            neither a tensor nor None, or ``output`` has more than one element.
    '''
    # Validate types first, so a wrong type yields the intended assertion
    # message rather than an AttributeError from calling .numel().
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__
    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"

    # Handle scalar output: implicit gradient of ones.
    if grad_output is None:
        grad_output = torch.ones_like(
            output,
            memory_format = torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors = (output,),
        grad_tensors = (grad_output,),
        keep_graph = False,
        create_graph = False,
        inputs = tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )


def forward_step(forward_step_func,
                 data_iterator,
                 model,
                 input_tensor,
                 forward_data_store,
                 timers,
                 collect_non_loss_data=False):
    """Run one forward pass of ``model`` on one microbatch.

    ``input_tensor`` (bare or list) is installed on the unwrapped model via
    ``set_input_tensor``; ``forward_step_func(data_iterator, model)`` then
    produces the output and a loss function. On the last pipeline stage the
    loss function is applied:

    * default: the returned loss is divided by the number of microbatches
      and becomes the output; the reduced loss is appended to
      ``forward_data_store``.
    * ``collect_non_loss_data``: ``loss_func(..., non_loss_data=True)`` is
      appended to ``forward_data_store`` and the output is left unchanged.

    Returns ``[output, last input]`` on stages after the split of an
    encoder-and-decoder model; otherwise the output, bare if
    ``input_tensor`` was bare and wrapped in a one-element list if not."""
    args = get_args()

    if timers is not None:
        timers('forward-compute', log_level=2).start()

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))

    # Keep track of whether the caller gave a bare tensor so the result can
    # be handed back in the same form.
    was_bare = not isinstance(input_tensor, list)
    inputs = [input_tensor] if was_bare else input_tensor

    unwrapped_model.set_input_tensor(inputs)
    output_tensor, loss_func = forward_step_func(data_iterator, model)

    if mpu.is_pipeline_last_stage():
        if collect_non_loss_data:
            forward_data_store.append(
                loss_func(output_tensor, non_loss_data=True))
        else:
            loss, loss_reduced = loss_func(output_tensor)
            output_tensor = loss / get_num_microbatches()
            forward_data_store.append(loss_reduced)

    if timers is not None:
        timers('forward-compute').stop()

    # Decoder stages of encoder-decoder models (e.g. T5) also pass the
    # encoder hidden state (last input) downstream.
    passes_encoder_state = (mpu.is_pipeline_stage_after_split() and
                            args.model_type == ModelType.encoder_and_decoder)
    if passes_encoder_state:
        return [output_tensor, inputs[-1]]
    return output_tensor if was_bare else [output_tensor]


def backward_step(optimizer, input_tensor, output_tensor,
                  output_tensor_grad, timers):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Each tensor argument may be a bare tensor or a list. When
    ``output_tensor_grad`` (first element) is None, the loss is first scaled
    via ``optimizer.scale_loss``.

    Returns gradient of loss with respect to input tensor (None if first
    stage). The result is a bare value if ``input_tensor`` was bare, and a
    list (one entry per input, None for None inputs) otherwise."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    args = get_args()

    if timers is not None:
        timers('backward-compute', log_level=2).start()

    # Normalize input_tensor to a list, remembering whether to unwrap the
    # returned gradient.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    # Retain the grad on the input tensors so it can be read back after
    # the backward pass.
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None:
        # NOTE(review): the [0] index below assumes scale_loss returns a
        # one-element, 1-D tensor (a 0-dim result could not be indexed).
        output_tensor = optimizer.scale_loss(output_tensor[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input tensors. input_tensor is always a list
    # at this point, so no None check on the list itself is needed.
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if timers is not None:
        timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    """No-op context manager used when no gradient-sync context is needed.

    forward_backward_no_pipelining uses it in place of ``model.no_sync``
    when the model is not a torch DistributedDataParallel instance. It
    yields None, and exceptions raised in the managed block propagate
    unchanged.
    """
    yield


def forward_backward_no_pipelining(forward_step_func,
                                   data_iterator, model,
                                   optimizer,
                                   timers,
                                   forward_only,
                                   collect_non_loss_data=False):
    """Run forward (and, unless forward_only, backward) passes for every
    microbatch without pipeline parallelism, so there is no inter-stage
    communication.

    ``model`` must be a one-element list. For a torch DDP model, every
    microbatch except the last runs inside ``model.no_sync()``. The last
    runs outside it so that gradients are synchronized.

    Returns ``forward_data_store``, the list that forward_step appends
    per-microbatch results to."""
    assert len(model) == 1
    model = model[0]

    no_sync = model.no_sync if isinstance(model, torchDDP) else dummy_handler

    forward_data_store = []
    # A single stage never receives an activation or an output gradient.
    input_tensor = None
    output_tensor_grad = None

    def run_microbatch():
        # One forward pass, followed by a backward pass when training.
        output_tensor = forward_step(forward_step_func, data_iterator,
                                     model, input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad, timers)

    with no_sync():
        for _ in range(get_num_microbatches() - 1):
            run_microbatch()

    # Final microbatch outside the handler (want to synchronize gradients).
    run_microbatch()

    return forward_data_store


def forward_backward_pipelining_with_interleaving(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer,
                                                  timers,
                                                  forward_only,
                                                  collect_non_loss_data=False):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    ``model`` and ``data_iterator`` are lists indexed by model chunk (the
    virtual pipeline rank set via mpu.set_virtual_pipeline_model_parallel_rank).
    Every p2p exchange uses one ``tensor_shape`` of
    (seq_length, micro_batch_size, hidden_size). seq_length is divided by the
    tensor-model-parallel world size when args.sequence_parallel is set.

    Phases: warmup forwards, steady-state 1F1B (one forward + one backward
    per iteration), then cooldown backwards. With forward_only, every
    microbatch is a warmup forward and no backward is run.

    Returns ``forward_data_store``, the list forward_step appends to (reduced
    losses, or non-loss data when collect_non_loss_data is True).
    forward_step appends only when mpu.is_pipeline_last_stage() is true, so
    other stages presumably return an empty list."""
    args = get_args()

    # Per-model-chunk FIFO queues: inputs/outputs of forward passes that
    # still await their backward pass.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    forward_data_store = []
    if not forward_only:
        # Per-model-chunk FIFO queues of gradients received from downstream.
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # With sequence parallelism the communicated sequence length is divided
    # by the tensor-model-parallel world size.
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    # num_microbatches counts (microbatch, model chunk) pairs.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Iterations are grouped into runs of pipeline_parallel_size per chunk.
        Forward passes visit chunks 0..num_model_chunks-1 in order; backward
        passes visit them in reverse."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # The first stage receives nothing from upstream: if this chunk has
        # no pending input, queue a None placeholder input.
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor,
                                     forward_data_store,
                                     timers,
                                     collect_non_loss_data)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # The last stage receives no gradient from downstream; a None
        # gradient makes backward_step start from the (scaled) loss.
        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        # Oldest pending forward of this chunk is processed first (FIFO).
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad,
                          timers)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        # On the final warmup iteration (when 1F1B follows), also receive
        # the first output gradient, for the last model chunk.
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        # Output has been sent; only its grad_fn is needed from here on.
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    # Forward passes run num_warmup_microbatches iterations ahead of the
    # backward passes.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        # If every microbatch ran during warmup, nothing was received for
        # the last model chunk yet; receive its first gradient now.
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))

    return forward_data_store


def get_tensor_shapes(rank, model_type):
    """Return the list of activation shapes that pipeline stage `rank` sends.

    Every shape is (sequence length, micro batch size, hidden size). Sequence
    lengths are divided by the tensor-model-parallel world size when
    args.sequence_parallel is set.

    * Encoder-and-decoder model (e.g. T5), rank after the split: two shapes.
      The first is the decoder activation (pre-transpose, decoder_seq_length)
      and the second is the encoder output (post-transpose, seq_length).
    * Encoder-and-decoder model, rank before the split: one encoder shape.
    * Any other model: one shape (pre-transpose, seq_length).
    """
    args = get_args()

    def local_length(full_length):
        # Divide by the tensor-parallel world size under sequence parallelism.
        if not args.sequence_parallel:
            return full_length
        return full_length // mpu.get_tensor_model_parallel_world_size()

    seq_length = local_length(args.seq_length)
    primary_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    if model_type != ModelType.encoder_and_decoder:
        return [primary_shape]

    decoder_seq_length = local_length(args.decoder_seq_length)
    if mpu.is_pipeline_stage_before_split(rank):
        return [primary_shape]
    return [(decoder_seq_length, args.micro_batch_size, args.hidden_size),
            primary_shape]


def recv_forward(tensor_shapes, timers):
    """Receive one activation per entry of `tensor_shapes` from the previous
    stage. A None shape yields a None placeholder and no receive is issued."""
    return [
        None if shape is None
        else p2p_communication.recv_forward(shape, timers=timers)
        for shape in tensor_shapes
    ]


def recv_backward(tensor_shapes, timers):
    """Receive one output gradient per entry of `tensor_shapes` from the next
    stage. A None shape yields a None placeholder and no receive is issued."""
    grads = []
    for shape in tensor_shapes:
        if shape is not None:
            grads.append(p2p_communication.recv_backward(shape, timers=timers))
        else:
            grads.append(None)
    return grads


def send_forward(output_tensors, tensor_shapes, timers):
    """Send each output tensor to the next stage, skipping None shapes.

    A bare (non-list) output is treated as a one-element list. Tensors and
    shapes are paired positionally; extra items on either side are ignored."""
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    to_send = [(tensor, shape)
               for tensor, shape in zip(output_tensors, tensor_shapes)
               if shape is not None]
    for tensor, shape in to_send:
        p2p_communication.send_forward(tensor, shape, timers=timers)


def send_backward(input_tensor_grads, tensor_shapes, timers):
    """Send each input gradient to the previous stage, skipping None shapes.

    A bare (non-list) gradient is treated as a one-element list."""
    grads = (input_tensor_grads if isinstance(input_tensor_grads, list)
             else [input_tensor_grads])
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, shape, timers=timers)


def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
    """For each (output, shape) pair, send the output forward and receive
    its gradient back. A None shape yields a None gradient and no
    communication. A bare output is treated as a one-element list."""
    tensors = (output_tensors if isinstance(output_tensors, list)
               else [output_tensors])
    grads = []
    for tensor, shape in zip(tensors, tensor_shapes):
        grads.append(
            None if shape is None
            else p2p_communication.send_forward_recv_backward(
                tensor, shape, timers=timers))
    return grads


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
    """For each (gradient, shape) pair, send the gradient backward and receive
    the next activation. A None shape yields a None activation and no
    communication. A bare gradient is treated as a one-element list."""
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    return [
        None if shape is None
        else p2p_communication.send_backward_recv_forward(
            grad, shape, timers=timers)
        for grad, shape in zip(input_tensor_grads, tensor_shapes)
    ]


def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator,
                                                     model,
                                                     optimizer,
                                                     timers,
                                                     forward_only,
                                                     collect_non_loss_data=False):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    ``model`` must be a one-element list. Phases:
      1. warmup: (pipeline world size - rank - 1) forward passes, capped at
         the number of microbatches;
      2. steady state: one forward and one backward per remaining microbatch
         (forward only when forward_only is True);
      3. cooldown: backward passes for the warmup microbatches.

    Returns ``forward_data_store``, the list forward_step appends to (reduced
    losses, or non-loss data when collect_non_loss_data is True).
    forward_step appends only when mpu.is_pipeline_last_stage() is true, so
    other stages presumably return an empty list."""
    # NOTE(review): `args` is not referenced anywhere below in this function.
    args = get_args()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Shapes received from upstream are those the previous rank (rank-1)
    # sends; shapes sent downstream are this rank's own.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            # recv_forward returns a list, so forward_step returns a list;
            # only its first tensor is pseudo-freed after being sent.
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

            # On the last steady-state iteration no further forward input is
            # needed, so only send the gradient; otherwise send it and
            # receive the next activation in one exchange.
            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return forward_data_store