# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
from contextlib import contextmanager
17
import torch
18
from torch.autograd.variable import Variable
19
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
20
21

from megatron import get_args
22
from megatron import get_num_microbatches
23
24
from megatron import get_timers
from megatron import mpu
25
from megatron import p2p_communication
26
27
28
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
29
30
from megatron.model import ModelType

31

def get_forward_backward_func():
    """Select the forward/backward schedule for the current parallel setup.

    Returns:
        ``forward_backward_no_pipelining`` when the pipeline-model-parallel
        world size is 1; otherwise
        ``forward_backward_pipelining_with_interleaving`` when
        ``args.virtual_pipeline_model_parallel_size`` is set, and
        ``forward_backward_pipelining_without_interleaving`` when it is None.

    Raises:
        AssertionError: if the interleaved schedule is selected and the number
            of microbatches is not divisible by the pipeline-model-parallel
            size.
    """
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    if args.virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving
    # The interleaved schedule assigns microbatches to model chunks in runs
    # of pipeline_model_parallel_size, so the count must divide evenly.
    assert get_num_microbatches() % \
        args.pipeline_model_parallel_size == 0, \
        'number of microbatches (%d) is not divisible by pipeline-' \
        'model-parallel-size (%d) when using interleaved schedule' % (
            get_num_microbatches(),
            args.pipeline_model_parallel_size,
        )
    return forward_backward_pipelining_with_interleaving


def deallocate_output_tensor(out):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.

    Args:
        out: tensor whose data should be replaced, or None (no-op).

    Raises:
        AssertionError: if ``out`` is not a tensor, or if it is a view of
            another tensor (freeing a view is counter-productive).
    '''
    if out is None:
        return
    assert isinstance(out, torch.Tensor), \
        "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, \
        "counter-productive to free a view of another tensor."
    # Replace the storage with a 1-element buffer on the same device and with
    # the same dtype; the tensor object (and its grad_fn) stays in place.
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )


def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'deallocate_output_tensor' (above) optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: single-element tensor to backpropagate from.
        grad_output: gradient with respect to ``output``, or None to use an
            implicit gradient of ones shaped like ``output``.

    Raises:
        AssertionError: if ``output`` is not a tensor or has more than one
            element, or if ``grad_output`` is neither a tensor nor None.
    '''
    # Check the type first so a non-tensor ``output`` fails with a clear
    # AssertionError rather than an AttributeError from ``.numel()``.
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__

    # Handle scalar output (numel() == 1 was already asserted above).
    if grad_output is None:
        grad_output = torch.ones_like(
            output,
            memory_format=torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )


def forward_step(forward_step_func,
                 data_iterator,
                 model,
                 input_tensor,
                 forward_data_store,
                 timers,
                 collect_non_loss_data=False):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Args:
        forward_step_func: callable ``(data_iterator, model)`` returning
            ``(output_tensor, loss_func)``.
        data_iterator: passed through to ``forward_step_func``.
        model: model to run; unwrapped from torchDDP / LocalDDP /
            Float16Module to call ``set_input_tensor``.
        input_tensor: input activation, or a list of them.
        forward_data_store: list that, on the last pipeline stage, receives
            either the reduced loss or the non-loss data from ``loss_func``.
        timers: timer factory, or None to disable timing.
        collect_non_loss_data: if True, store
            ``loss_func(output_tensor, non_loss_data=True)`` instead of
            computing the loss.

    Returns output tensor. On the last stage (when not collecting non-loss
    data) it is the loss divided by the number of microbatches. For stages
    after the split of an encoder-and-decoder model, returns
    ``[output_tensor, input_tensor[-1]]``; otherwise returns a bare tensor if
    ``input_tensor`` was not a list, else a one-element list."""
    args = get_args()

    if timers is not None:
        timers('forward-compute', log_level=2).start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))

    # Normalize to a list; remember whether to unwrap the result again.
    unwrap_output_tensor = not isinstance(input_tensor, list)
    if unwrap_output_tensor:
        input_tensor = [input_tensor]

    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        if not collect_non_loss_data:
            loss, loss_reduced = loss_func(output_tensor)
            # Gradients accumulate over all microbatches of the step, so
            # scaling each loss by 1/num_microbatches yields their mean.
            output_tensor = loss / get_num_microbatches()
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if timers is not None:
        timers('forward-compute').stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]


def backward_step(optimizer, input_tensor, output_tensor,
                  output_tensor_grad, timers):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        optimizer: its ``scale_loss`` is applied to the loss when
            ``output_tensor_grad`` is None (last stage).
        input_tensor: tensor, None, or list of them that fed this stage's
            forward pass.
        output_tensor: tensor or list produced by the forward pass.
        output_tensor_grad: gradient(s) w.r.t. ``output_tensor``; None on the
            last stage.
        timers: timer factory, or None to disable timing.

    Returns gradient of loss with respect to input tensor (None if first
    stage). A single value is returned when ``input_tensor`` was not a list,
    otherwise a list with one entry per input."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    args = get_args()

    if timers is not None:
        timers('backward-compute', log_level=2).start()

    # Normalize to a list; remember whether to unwrap the result again.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    # Retain the grad on the input_tensor so x.grad is populated below.
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass. Without an upstream gradient the output is the loss,
    # which the optimizer scales before backprop.
    # NOTE(review): ``[0]`` is applied to scale_loss's result, which assumes it
    # returns a 1-element 1-D tensor rather than a 0-dim one -- confirm.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    custom_backward(output_tensor[0], output_tensor_grad[0])

    # Collect the grad of the input_tensor (always a list at this point).
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
            mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if timers is not None:
        timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    """No-op context manager.

    Stands in for DDP's ``model.no_sync`` when the model is not wrapped in
    torch DDP. It does nothing on entry or exit, yields None, and lets
    exceptions from the ``with`` body propagate unchanged.
    """
    yield


def forward_backward_no_pipelining(forward_step_func,
                                   data_iterator, model,
                                   optimizer,
                                   timers,
                                   forward_only,
                                   collect_non_loss_data=False):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Args:
        forward_step_func: passed to ``forward_step`` for every microbatch.
        data_iterator: passed to ``forward_step``.
        model: one-element list holding the model.
        optimizer: passed to ``backward_step`` (loss scaling).
        timers: timer factory, or None to disable timing.
        forward_only: if True, skip all backward passes.
        collect_non_loss_data: passed to ``forward_step``.

    Returns:
        ``forward_data_store``: the list of values appended by
        ``forward_step`` (one per microbatch on the last pipeline stage).
    """
    assert len(model) == 1
    model = model[0]

    # With torch DDP, suppress gradient all-reduce for all but the last
    # microbatch; otherwise use a no-op context.
    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    forward_data_store = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for _ in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator,
                                         model, input_tensor, forward_data_store,
                                         timers, collect_non_loss_data)
            if not forward_only:
                # Argument order must match backward_step's signature:
                # (optimizer, input_tensor, output_tensor, output_tensor_grad, timers).
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator,
                                 model, input_tensor, forward_data_store,
                                 timers, collect_non_loss_data)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor,
                      output_tensor_grad, timers)

    return forward_data_store


def forward_backward_pipelining_with_interleaving(forward_step_func,
                                                  data_iterator, model,
                                                  optimizer,
                                                  timers,
                                                  forward_only,
                                                  collect_non_loss_data=False):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Args:
        forward_step_func: passed to ``forward_step`` for every microbatch.
        data_iterator: list with one data iterator per model chunk.
        model: list of model chunks held by this pipeline rank.
        optimizer: passed to ``backward_step`` (loss scaling).
        timers: timer factory, or None to disable timing.
        forward_only: if True, every microbatch runs as a warmup forward pass
            and no tensors are kept for backward.
        collect_non_loss_data: passed to ``forward_step``.

    Returns:
        ``forward_data_store``: the list of values appended by
        ``forward_step``, which only appends when
        ``mpu.is_pipeline_last_stage()`` is true (empty elsewhere).
    """
    args = get_args()

    # Per-model-chunk FIFO queues of input activations, outputs and
    # upstream gradients.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # With sequence parallelism the exchanged activations carry only this
    # tensor-parallel rank's slice of the sequence dimension.
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Consecutive runs of ``pipeline_parallel_size`` microbatches map to the
        same chunk; backward passes visit the chunks in reverse order.
        """
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # On the first (virtual) stage nothing arrives from upstream: queue a
        # None input unless one is already pending for this microbatch.
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor,
                                     forward_data_store,
                                     timers,
                                     collect_non_loss_data)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # The last (virtual) stage has no downstream gradient; a None grad
        # makes backward_step start from the (scaled) loss.
        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad,
                          timers)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        # On the final warmup step (when 1F1B follows) also receive the first
        # output gradient for the last model chunk.
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))

    return forward_data_store


def get_tensor_shapes(rank, model_type):
    """Return the shapes of the activation tensors exchanged by ``rank``.

    Args:
        rank: pipeline-model-parallel rank whose outgoing tensors are
            described (callers pass ``rank - 1`` for tensors they receive).
        model_type: a ``ModelType`` value.

    Returns:
        list of ``(seq_length, micro_batch_size, hidden_size)`` tuples; the
        sequence length is divided by the tensor-model-parallel world size
        when ``args.sequence_parallel`` is set.
    """
    # Determine right tensor sizes (based on position of rank with respect to split
    # rank) and model size.
    # Send two tensors if model is T5 and rank is in decoder stage:
    #     first tensor is decoder (pre-transpose),
    #     second tensor is encoder (post-transpose).
    # If model is T5 and rank is at the boundary:
    #     send one tensor (post-transpose from encoder).
    # Otherwise, send one tensor (pre-transpose).
    args = get_args()
    tensor_shapes = []

    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length

    if model_type == ModelType.encoder_and_decoder:
        if args.sequence_parallel:
            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
        else:
            decoder_seq_length = args.decoder_seq_length

        if mpu.is_pipeline_stage_before_split(rank):
            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
        else:
            tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
    else:
        tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
    return tensor_shapes


def recv_forward(tensor_shapes, timers):
    """Receive one activation from the previous stage per entry of *tensor_shapes*.

    A ``None`` shape is a placeholder: nothing is received for it and the
    corresponding slot of the result is ``None``.
    """
    return [
        None if shape is None
        else p2p_communication.recv_forward(shape, timers=timers)
        for shape in tensor_shapes
    ]


def recv_backward(tensor_shapes, timers):
    """Receive one gradient from the next stage per entry of *tensor_shapes*.

    A ``None`` shape is a placeholder: nothing is received for it and the
    corresponding slot of the result is ``None``.
    """
    return [
        None if shape is None
        else p2p_communication.recv_backward(shape, timers=timers)
        for shape in tensor_shapes
    ]


def send_forward(output_tensors, tensor_shapes, timers):
    """Send each output tensor to the next stage, skipping ``None`` shapes.

    A non-list *output_tensors* is treated as a one-element list; tensors and
    shapes are paired with ``zip``, so extra entries on either side are ignored.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, shape, timers=timers)


def send_backward(input_tensor_grads, tensor_shapes, timers):
    """Send each input gradient to the previous stage, skipping ``None`` shapes.

    A non-list *input_tensor_grads* is treated as a one-element list; grads
    and shapes are paired with ``zip``, so extra entries are ignored.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, shape, timers=timers)


def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
    """Send outputs downstream and receive their gradients, pair by pair.

    A non-list *output_tensors* is treated as a one-element list. Pairs with
    a ``None`` shape are skipped and yield ``None`` in the result; the result
    has one entry per ``zip``-ed (tensor, shape) pair.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    return [
        None if shape is None
        else p2p_communication.send_forward_recv_backward(
                tensor, shape, timers=timers)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
    """Send gradients upstream and receive the next activations, pair by pair.

    A non-list *input_tensor_grads* is treated as a one-element list. Pairs
    with a ``None`` shape are skipped and yield ``None`` in the result; the
    result has one entry per ``zip``-ed (grad, shape) pair.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    return [
        None if shape is None
        else p2p_communication.send_backward_recv_forward(
                grad, shape, timers=timers)
        for grad, shape in zip(grads, tensor_shapes)
    ]


def forward_backward_pipelining_without_interleaving(forward_step_func,
                                                     data_iterator,
                                                     model,
                                                     optimizer,
                                                     timers,
                                                     forward_only,
                                                     collect_non_loss_data=False):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Args:
        forward_step_func: passed to ``forward_step`` for every microbatch.
        data_iterator: passed to ``forward_step``.
        model: one-element list holding this stage's model.
        optimizer: passed to ``backward_step`` (loss scaling).
        timers: timer factory, or None to disable timing.
        forward_only: if True, run only forward passes and keep no tensors
            for backward.
        collect_non_loss_data: passed to ``forward_step``.

    Returns:
        ``forward_data_store``: the list of values appended by
        ``forward_step``, which only appends on the last pipeline stage
        (empty elsewhere).
    """
    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    model_type = unwrapped_model.model_type
    rank = mpu.get_pipeline_model_parallel_rank()
    # Received tensors are shaped as the previous rank sends them.
    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
    send_tensor_shapes = get_tensor_shapes(rank, model_type)

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            # The output was already sent; only its autograd graph is needed.
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return forward_data_store