schedules.py 49.6 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

liangjing's avatar
v1  
liangjing committed
3
4
import contextlib
from typing import Callable, Iterator, List, Optional, Union
5

6
import torch
7
from torch.autograd.variable import Variable
8
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
9

liangjing's avatar
v1  
liangjing committed
10
from megatron import core
11
12
from megatron.core import parallel_state
from megatron.core.enums import ModelType
liangjing's avatar
v1  
liangjing committed
13
14
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
15

16
17
# Types
# A tensor shape, written either as a plain list of ints or as a torch.Size.
Shape = Union[List[int], torch.Size]
18

liangjing's avatar
v1  
liangjing committed
19

Jared Casper's avatar
Jared Casper committed
20
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state:

    * pipeline-parallel world size of 1 (or less):
      ``forward_backward_no_pipelining``
    * pipeline-parallel world size > 1 and a virtual pipeline world size
      is set: ``forward_backward_pipelining_with_interleaving``
    * pipeline-parallel world size > 1 otherwise:
      ``forward_backward_pipelining_without_interleaving``

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and returns the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. It is only passed when it is not None; otherwise
        forward_step_func is called with just (data_iterator, model) and
        the default from the configuration should be used. This is used
        when num_microbatches_with_partial_activation_checkpoints is set
        in the model config.

        For example:

        from functools import partial

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)


        forward_backward_func(forward_step_func=forward_step, ...)


    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func. Expected to be a list of
        iterators in the case of interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in the case of interleaved
        pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required):
        The number of microbatches to go through

    seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
        transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
        in the config is True. Otherwise, each microbatch in the current global batch size must use
        this sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
        transformer. This is ignored for a single-stack transformer.

    forward_only (optional, default = False): Perform only the forward step.

    collect_non_loss_data (optional, bool, default=False): If True, on the
        last pipeline stage the loss function is called as
        ``loss_func(output_tensor, non_loss_data=True)`` and its return value
        is collected instead of the reduced loss; the loss is then not
        computed or scaled by num_microbatches for that microbatch.
    """
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
        return forward_backward_pipelining_with_interleaving
    return forward_backward_pipelining_without_interleaving

liangjing's avatar
v1  
liangjing committed
104
105

def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    '''Shrink an output tensor's '.data' to a single element ("pseudo-free" it).

    Intended to be called right after ``out`` has been sent to the next
    pipeline stage. From then on the tensor is only needed for its autograd
    graph ('.grad_fn'), not for its values.

    Does nothing when ``out`` is None or when ``deallocate_pipeline_outputs``
    is falsy.

    Raises:
        AssertionError: if ``out`` is not a torch.Tensor, or if it is a view
            of another tensor (``out._base`` is set).
    '''
    should_free = out is not None and deallocate_pipeline_outputs
    if not should_free:
        return

    type_name = type(out).__name__
    assert isinstance(out, torch.Tensor), "expected Tensor, found {}.".format(type_name)
    assert out._base is None, "counter-productive to free a view of another tensor."

    # Keep device and dtype, but replace the storage with a 1-element placeholder.
    placeholder = torch.empty((1,), device=out.device, dtype=out.dtype)
    out.data = placeholder

118

119
def custom_backward(output, grad_output):
    '''Run backward on ``output`` by calling the C++ autograd engine directly.

    To make the 'deallocate_output_tensor' optimization work (the output's
    '.data' is replaced by a one-element placeholder while its '.grad_fn' is
    kept), the C++ autograd engine must be called directly, bypassing
    Pytorch's torch.autograd.backward. Pytorch's 'backward' checks that the
    output and grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: the (pseudo-freed) output tensor; must have exactly one element.
        grad_output: gradient with respect to ``output``, or None to use
            ``torch.ones_like(output)`` (implicit gradient of a scalar output).

    Raises:
        AssertionError: if ``output`` is not a tensor or has more than one
            element, or if ``grad_output`` is neither a tensor nor None.
    '''
    # Check the type first so that a non-tensor fails with a clear assertion
    # message rather than an AttributeError from ``.numel()``.
    assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
    assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(grad_output, (torch.Tensor, type(None))), (
        "grad_output == '%s'." % type(grad_output).__name__
    )

    # Handle scalar output (already guaranteed to have a single element above).
    if grad_output is None:
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format,)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
149
150


liangjing's avatar
v1  
liangjing committed
151
152
153
154
155
156
157
158
159
160
161
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
):
    """Run ``forward_step_func`` for one microbatch on ``model``.

    ``input_tensor`` (wrapped in a list if needed) is always handed to the
    model through its ``set_input_tensor`` hook. On the first stage the data
    is expected to come from ``data_iterator``; on later stages the passed-in
    input tensor is what the model consumes.

    On the last pipeline stage, the loss function returned by
    ``forward_step_func`` is applied:

    * normally, the loss is divided by ``num_microbatches`` and becomes the
      output, and the reduced-loss dict is appended to ``forward_data_store``;
    * with ``collect_non_loss_data``, the result of
      ``loss_func(output_tensor, non_loss_data=True)`` is appended instead,
      and the output is left unchanged.

    Returns:
        The output tensor. It is wrapped in a single-element list if
        ``input_tensor`` was passed as a list. For encoder-decoder models past
        the pipeline split, it returns ``[output_tensor, input_tensor[-1]]``
        so the encoder hidden state is forwarded downstream too.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    was_single_tensor = not isinstance(input_tensor, list)
    input_tensors = [input_tensor] if was_single_tensor else input_tensor

    get_attr_wrapped_model(model, "set_input_tensor")(input_tensors)

    autocast_ctx = (
        torch.autocast("cuda", dtype=config.autocast_dtype)
        if config.enable_autocast
        else contextlib.nullcontext()
    )
    # The checkpoint flag is only forwarded when explicitly provided.
    step_args = (data_iterator, model)
    if checkpoint_activations_microbatch is not None:
        step_args += (checkpoint_activations_microbatch,)
    with autocast_ctx:
        output_tensor, loss_func = forward_step_func(*step_args)

    if parallel_state.is_pipeline_last_stage():
        if collect_non_loss_data:
            forward_data_store.append(loss_func(output_tensor, non_loss_data=True))
        else:
            loss, loss_reduced = loss_func(output_tensor)
            output_tensor = loss / num_microbatches
            forward_data_store.append(loss_reduced)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Encoder-decoder models (e.g. T5) in the decoder half of the pipeline
    # also pass the encoder hidden state downstream.
    model_type = get_model_type(model)
    is_decoder_side = parallel_state.is_pipeline_stage_after_split()
    if is_decoder_side and model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensors[-1]]
    return output_tensor if was_single_tensor else [output_tensor]
216
217


liangjing's avatar
v1  
liangjing committed
218
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        input_tensor: tensor (or list of tensors, entries may be None) this
            stage received in the forward pass; None on the first stage.
        output_tensor: output of ``forward_step`` (a tensor or a list).
        output_tensor_grad: gradient(s) with respect to ``output_tensor``;
            None on the last stage.
        model_type: ``ModelType`` of the model; used to detect the
            encoder-decoder skip connection.
        config: model config; ``timers``, ``grad_scale_func`` and
            ``deallocate_pipeline_outputs`` are read from it.

    Returns gradient of loss with respect to input tensor (None if first
    stage). A list is returned when ``input_tensor`` was passed as a list.
    """

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Normalize every argument to a list, remembering whether the caller
    # passed a bare input tensor so the result can be unwrapped again.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Retain the grad on the input tensors so it can be read back below.
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    # Backward pass. With no incoming gradient (last stage), the output is
    # first passed through config.grad_scale_func when one is configured.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        # The output may have been pseudo-freed, so bypass the shape check.
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input tensors. ``input_tensor`` is always a list
    # at this point, so no separate None check on the container is needed.
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad


liangjing's avatar
v1  
liangjing committed
285
286
287
288
289
290
291
292
293
294
295
296
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: Optional[int] = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    The first ``num_microbatches - 1`` microbatches run inside the model's
    no-sync context (``config.no_sync_func``, else ``model.no_sync`` for a
    torch DDP model, else a null context). The last microbatch runs outside
    it so that gradients are synchronized.

    Returns:
        list: ``forward_data_store``, holding the entries appended by
        ``forward_step``: the reduced-loss dicts (or non-loss data when
        ``collect_non_loss_data`` is True), one per microbatch when this rank
        is the last pipeline stage.

    See get_forward_backward_func() for argument details
    """
    if isinstance(model, list):
        assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)

    no_sync_func = config.no_sync_func
    if no_sync_func is None and isinstance(model, torchDDP):
        no_sync_func = model.no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext

    model_type = get_model_type(model)

    forward_data_store = []
    # No inter-stage communication: there is never a received input tensor
    # or an incoming output gradient.
    input_tensor, output_tensor_grad = None, None
    with no_sync_func():
        for _ in range(num_microbatches - 1):
            output_tensor = forward_step(
                forward_step_func,
                data_iterator,
                model,
                num_microbatches,
                input_tensor,
                forward_data_store,
                config,
                collect_non_loss_data,
            )
            if not forward_only:
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        collect_non_loss_data,
    )

    if not forward_only:
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    return forward_data_store
359
360


liangjing's avatar
v1  
liangjing committed
361
362
363
364
365
366
367
368
369
370
371
372
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
373
374
375
376
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
liangjing's avatar
v1  
liangjing committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
    assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"

    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model):

        def multi_no_sync():
            stack = contextlib.ExitStack()
            for chunk in model:
                stack.enter_context(chunk.no_sync())
            return stack

        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Model chunk IDs with synchronized grads
    synchronized_model_chunks = set()
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
420

421
422
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
423
    forward_data_store = []
424
425
426
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

427
428
429
430
431
432
433
434
435
436
437
438
439
    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()

    if num_microbatches % pipeline_parallel_size != 0:
        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
        msg += 'when using interleaved schedule'
        raise RuntimeError(msg)

    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

liangjing's avatar
v1  
liangjing committed
440
441
442
    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
443
        )
444

liangjing's avatar
v1  
liangjing committed
445
446
447
448
    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

449
450
    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
451
    total_num_microbatches = num_microbatches * num_model_chunks
452
453
    all_warmup_microbatches = False
    if forward_only:
454
        num_warmup_microbatches = total_num_microbatches
455
    else:
456
457
458
459
460
461
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
462
463
        if num_microbatches == pipeline_parallel_size:
            num_warmup_microbatches = total_num_microbatches
464
465
            all_warmup_microbatches = True
        else:
liangjing's avatar
v1  
liangjing committed
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
        config.param_sync_func(model[0].parameters())
        config.param_sync_func(model[1].parameters())
487

488
    def get_model_chunk_id(microbatch_id, forward):
489
        """Helper method to get the model chunk ID given the iteration number."""
490
491
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
492
        if not forward:
liangjing's avatar
v1  
liangjing committed
493
            model_chunk_id = num_model_chunks - model_chunk_id - 1
494
        return model_chunk_id
495

liangjing's avatar
v1  
liangjing committed
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
            return microbatch_id_in_group % pipeline_parallel_size == 0
        else:
            return False

    def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the last for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == num_microbatch_groups - 1:
            return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
        else:
            return False

    def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
519
520
521
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
522
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
523
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
524

liangjing's avatar
v1  
liangjing committed
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
        # launch param synchronization for next model chunk
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.param_sync_func is not None:
            param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
            if (
                param_sync_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func(model[param_sync_chunk_id].parameters())

540
        # forward step
541
        if parallel_state.is_pipeline_first_stage():
liangjing's avatar
v1  
liangjing committed
542
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
543
544
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
liangjing's avatar
v1  
liangjing committed
545
546
547
548
549
550
551
552
553
554
555
        output_tensor = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )
556
557
        output_tensors[model_chunk_id].append(output_tensor)

558
559
560
561
562
        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

563
564
        return output_tensor

565
    def backward_step_helper(microbatch_id):
566
567
568
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
569
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
570
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
571

liangjing's avatar
v1  
liangjing committed
572
573
574
575
576
        # launch grad synchronization (default)
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)

577
        if parallel_state.is_pipeline_last_stage():
578
579
580
581
582
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
liangjing's avatar
v1  
liangjing committed
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )

        # launch grad synchronization (custom grad sync)
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.grad_sync_func is not None:
            grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
            if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                grad_sync_microbatch_id
            ):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func(model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()
602
603
604
605

        return input_tensor_grad

    # Run warmup forward passes.
606
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
liangjing's avatar
v1  
liangjing committed
607
608
609
610
611
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
    bwd_wait_handles = None

612
    for k in range(num_warmup_microbatches):
liangjing's avatar
v1  
liangjing committed
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627

        if fwd_wait_handles is not None:
            for req in fwd_wait_handles:
                req.wait()

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
628
629

        # Determine if tensor should be received from previous stage.
liangjing's avatar
v1  
liangjing committed
630
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
631
        recv_prev = True
632
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
633
634
            if next_forward_model_chunk_id == 0:
                recv_prev = False
635
        if k == (total_num_microbatches - 1):
636
            recv_prev = False
637
638

        # Don't send tensor downstream if on last stage.
639
        if parallel_state.is_pipeline_last_stage():
640
            output_tensor = None
641
642
643

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
liangjing's avatar
v1  
liangjing committed
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
        if not config.overlap_p2p_comm:
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
671
        else:
liangjing's avatar
v1  
liangjing committed
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                (
                    output_tensor_grad,
                    bwd_wait_handles,
                ) = p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )

                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
705
706
707
708
709
710

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

liangjing's avatar
v1  
liangjing committed
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles:
                    req.wait()

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

            output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

            # Last virtual stage no activation tensor to send
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
751

liangjing's avatar
v1  
liangjing committed
752
753
754
755
            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False
756

liangjing's avatar
v1  
liangjing committed
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
            # Send activation tensor to the next stage and receive activation tensor from the
            # previous stage
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            # assert fwd_wait_handles is not None

            if bwd_wait_handles is not None:
                for req in bwd_wait_handles:
                    req.wait()

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)

            # First virtual stage no activation gradient tensor to send
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if the current virtual stage has an activation gradient tensor to receive
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
                input_tensor_grad,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

        else:  # no p2p overlap
            output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
839

liangjing's avatar
v1  
liangjing committed
840
841
842
843
844
845
846
847
848
849
850
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
851

liangjing's avatar
v1  
liangjing committed
852
853
854
            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
855
856
                recv_prev = False

liangjing's avatar
v1  
liangjing committed
857
858
859
860
861
862
863
864
865
866
867
868
869
            # Communicate tensors.
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
870

871
872
        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
873
874
875
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
liangjing's avatar
v1  
liangjing committed
876
877
878
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
879

880
    # Run cooldown backward passes (flush out pipeline).
881
    if not forward_only:
liangjing's avatar
v1  
liangjing committed
882
883
884
885
        if config.overlap_p2p_comm and bwd_wait_handles is not None:
            for wait_handle in bwd_wait_handles:
                wait_handle.wait()

886
        if all_warmup_microbatches:
liangjing's avatar
v1  
liangjing committed
887
888
889
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
890
        for k in range(num_microbatches_remaining, total_num_microbatches):
891
            input_tensor_grad = backward_step_helper(k)
liangjing's avatar
v1  
liangjing committed
892
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
893
            recv_next = True
894
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
895
896
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
897
            if k == (total_num_microbatches - 1):
898
899
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
900
                p2p_communication.send_backward_recv_backward(
liangjing's avatar
v1  
liangjing committed
901
902
903
904
905
906
907
908
909
910
911
912
913
914
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )
            )

    # Launch any remaining grad reductions
    enable_grad_sync()
    if config.grad_sync_func is not None:
        params = []
        for model_chunk_id in range(num_model_chunks):
            if model_chunk_id not in synchronized_model_chunks:
                params.extend(model[model_chunk_id].parameters())
                synchronized_model_chunks.add(model_chunk_id)
        if params:
            config.grad_sync_func(params)
915

916
    return forward_data_store
917

liangjing's avatar
v1  
liangjing committed
918
919
920
921
922
923
924
925
926
927

def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
):
    """Return the list of tensor shapes a pipeline rank communicates.

    Every shape is ``(sequence_length, micro_batch_size, config.hidden_size)``.
    With ``config.sequence_parallel`` the sequence lengths are first divided by
    the tensor-model-parallel world size (the decoder length only for
    encoder-and-decoder models).

    - Non encoder-and-decoder models: one tensor of length ``seq_length``.
    - Encoder-and-decoder (T5-style) models, ``rank`` before the split:
      one tensor of length ``seq_length``.
    - Encoder-and-decoder models, ``rank`` at/after the split: two tensors,
      the decoder one (``decoder_seq_length``) first, then the encoder one
      (``seq_length``).
    """
    hidden = config.hidden_size
    sequence_parallel = config.sequence_parallel
    is_encoder_decoder = model_type == ModelType.encoder_and_decoder

    if sequence_parallel:
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        seq_length //= tp_world_size
        if is_encoder_decoder:
            decoder_seq_length //= tp_world_size

    encoder_shape = (seq_length, micro_batch_size, hidden)

    # Decoder-only / encoder-only models, and encoder stages of an
    # encoder-decoder model, exchange a single tensor.
    if not is_encoder_decoder or parallel_state.is_pipeline_stage_before_split(rank):
        return [encoder_shape]

    # Decoder stages of an encoder-decoder model carry the decoder hidden
    # states followed by the encoder output.
    return [(decoder_seq_length, micro_batch_size, hidden), encoder_shape]


liangjing's avatar
v1  
liangjing committed
956
def recv_forward(tensor_shapes, config):
    """Receive the forward-pass input tensors for this stage.

    ``p2p_communication.recv_forward`` is called once per non-``None`` entry of
    ``tensor_shapes`` (in order). A ``None`` shape means nothing is received
    for that slot and ``None`` is placed at the same position in the result.
    """
    return [
        None if shape is None else p2p_communication.recv_forward(shape, config)
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
966
def recv_backward(tensor_shapes, config):
    """Receive the output-gradient tensors used by this stage's backward pass.

    ``p2p_communication.recv_backward`` is called once per non-``None`` entry
    of ``tensor_shapes`` (in order). ``None`` shapes yield ``None`` at the same
    position in the returned list.
    """
    return [
        p2p_communication.recv_backward(shape, config) if shape is not None else None
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
976
def send_forward(output_tensors, tensor_shapes, config):
    """Send forward-pass output tensors via ``p2p_communication.send_forward``.

    ``output_tensors`` may be a single tensor or a list. It is paired
    positionally with ``tensor_shapes`` (like ``zip``, extra items on either
    side are ignored). Slots whose shape is ``None`` are not sent.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, config)
983
984


liangjing's avatar
v1  
liangjing committed
985
def send_backward(input_tensor_grads, tensor_shapes, config):
    """Send input-gradient tensors via ``p2p_communication.send_backward``.

    ``input_tensor_grads`` may be a single tensor or a list. It is paired
    positionally with ``tensor_shapes`` (like ``zip``, extra items on either
    side are ignored). Slots whose shape is ``None`` are not sent.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, config)
992
993


liangjing's avatar
v1  
liangjing committed
994
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
    """Send forward outputs and receive the matching output gradients.

    ``output_tensors`` may be a single tensor or a list. It is paired
    positionally with ``tensor_shapes`` (like ``zip``). For each pair with a
    non-``None`` shape, ``p2p_communication.send_forward_recv_backward`` is
    called and its result collected. ``None`` shapes contribute ``None``.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    return [
        None
        if shape is None
        else p2p_communication.send_forward_recv_backward(tensor, shape, config)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1009
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
    """Send input gradients and receive the next forward-pass inputs.

    ``input_tensor_grads`` may be a single tensor or a list. It is paired
    positionally with ``tensor_shapes`` (like ``zip``). For each pair with a
    non-``None`` shape, ``p2p_communication.send_backward_recv_forward`` is
    called and its result collected. ``None`` shapes contribute ``None``.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    return [
        None
        if shape is None
        else p2p_communication.send_backward_recv_forward(grad, shape, config)
        for grad, shape in zip(grads, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: Optional[int] = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
    """Run the non-interleaved 1F1B schedule, with communication between
    pipeline stages.

    Each pipeline rank runs three phases:

    1. Warmup: ``pp_world_size - pp_rank - 1`` forward-only passes (capped at
       ``num_microbatches``), so earlier stages run ahead of later ones.
    2. Steady state (1F1B): the remaining microbatches alternate one forward
       pass with one backward pass.
    3. Cooldown: the backward passes for the warmup microbatches.

    Gradient reductions are disabled for the duration of the schedule. They
    are re-enabled before the final backward pass on this rank when
    ``config.grad_sync_func`` is None or this is pipeline rank 0. Otherwise
    they are re-enabled at the end, where ``config.grad_sync_func`` (if set)
    is called on the model parameters.

    Args:
        forward_step_func: user forward-step function, passed to ``forward_step``.
        data_iterator: data iterator, or a single-element list containing one.
        model: model, or a single-element list containing it. Model chunking
            is not supported by this schedule.
        num_microbatches: number of microbatches to run.
        seq_length: sequence length, used to build communication shapes.
        micro_batch_size: micro-batch size, used to build communication shapes.
        decoder_seq_length: decoder sequence length (encoder-and-decoder
            models), used to build communication shapes.
        forward_only: if True, no backward passes run and no activations are
            kept.
        collect_non_loss_data: forwarded to ``forward_step``.

    Returns:
        ``forward_data_store``, the list passed to every ``forward_step`` call.
        Presumably it holds per-microbatch outputs/losses on the last stage and
        stays empty elsewhere (see ``forward_step``).

    Raises:
        ValueError: if ``config.overlap_p2p_comm`` is set.
    """
    if isinstance(model, list):
        assert (
            len(model) == 1
        ), "non-interleaved pipeline parallelism does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-interleaved pipeline parallelism does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.overlap_p2p_comm:
        raise ValueError(
            "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
        )

    # Disable async grad reductions. Prefer the configured no-sync context,
    # then DDP's own ``no_sync``, and fall back to a no-op context.
    no_sync_func = config.no_sync_func
    if no_sync_func is None and isinstance(model, torchDDP):
        no_sync_func = model.no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions (enter the no-sync context once)."""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions (exit the no-sync context if active)."""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches: earlier stages run more forward
    # passes before their first backward pass.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)

    rank = parallel_state.get_pipeline_model_parallel_rank()
    # What this rank receives is what the previous rank (rank - 1) sends.
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = recv_forward(recv_tensor_shapes, config)
        output_tensor = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )
        send_forward(output_tensor, send_tensor_shapes, config)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, config)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, config)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, config)

        else:
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, send_tensor_shapes, config
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            # With no warmup microbatches (e.g. the last pipeline stage) the
            # cooldown loop below is empty, so this is the final backward pass
            # on this rank: enable grad sync for it, using the same rule as
            # the cooldown loop. Otherwise it would never run with sync on.
            if num_warmup_microbatches == 0 and last_iteration:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, config)
            else:
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, recv_tensor_shapes, config
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, config)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            send_backward(input_tensor_grad, recv_tensor_shapes, config)

    # Launch any remaining grad reductions (only if sync was not already
    # re-enabled for the final backward pass above).
    if no_sync_context is not None:
        enable_grad_sync()
        if config.grad_sync_func is not None:
            config.grad_sync_func(model.parameters())

    return forward_data_store