schedules.py 78.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

liangjing's avatar
v1  
liangjing committed
3
import contextlib
xingjinliang's avatar
xingjinliang committed
4
from typing import Iterator, List, Union
5

6
import torch
7
from torch.autograd.variable import Variable
8

9
10
from megatron.core import parallel_state
from megatron.core.enums import ModelType
liangjing's avatar
v1  
liangjing committed
11
from megatron.core.pipeline_parallel import p2p_communication
xingjinliang's avatar
xingjinliang committed
12
13
14
15
16
17
18
19
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    drain_embedding_wgrad_compute,
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
    get_model_xattn,
)
20

21
22
# Types
Shape = Union[List[int], torch.Size]
23

liangjing's avatar
v1  
liangjing committed
24

Jared Casper's avatar
Jared Casper committed
25
def get_forward_backward_func():
    """Return the forward/backward schedule matching the global parallel_state.

    The choice depends only on the pipeline model parallel world size and on
    whether a virtual pipeline model parallel world size is configured:

    * pipeline size not greater than 1 -> ``forward_backward_no_pipelining``
    * pipeline size > 1, no virtual size -> ``forward_backward_pipelining_without_interleaving``
    * pipeline size > 1, virtual size set -> ``forward_backward_pipelining_with_interleaving``

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The returned function takes the following arguments:

    forward_step_func (required): A function that takes a data iterator and a
        model as its arguments and returns the model's forward output and the
        loss function. The loss function should take one torch.Tensor and
        return a torch.Tensor of loss and a dictionary of
        string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates that
        the activations for this microbatch should be checkpointed. A None
        value for this argument indicates that the default from the
        configuration should be used. This is used when
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)


        forward_backward_func(forward_step_func=forward_step, ...)


    data_iterator (required): an iterator over the data, passed as is to
        forward_step_func. Expected to be a list of iterators in the case of
        interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in
        the case of interleaved pipeline parallelism. Must be a (potentially
        wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required): The number of microbatches to go through.

    seq_length (int, required): Sequence length of the current global batch.
        If this is a dual-stack transformer, this is the encoder's sequence
        length. This is ignored if variable_seq_lengths in the config is True.
        Otherwise, each microbatch in the current global batch must use this
        sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in
        a dual-stack transformer. This is ignored for a single-stack
        transformer.

    forward_only (optional, default = False): Perform only the forward step.

    collect_non_loss_data (optional, bool, default=False): TODO

    first_val_step (bool, optional): Is the first step of the validation
        phase. Used by Transformer Engine modules to update their fp8 weights
        only on the first validation step.
    """
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        # A single pipeline stage: no inter-stage communication needed.
        return forward_backward_no_pipelining
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None:
        return forward_backward_pipelining_without_interleaving
    # Virtual pipeline stages configured: use the interleaved 1F1B schedule.
    return forward_backward_pipelining_with_interleaving

liangjing's avatar
v1  
liangjing committed
113
114

def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    '''Replace ``out.data`` with a one-element placeholder ("pseudo-deallocate").

    Call this right after ``out`` has been sent to the next pipeline stage.
    From then on the tensor is only useful for its ``.grad_fn`` field, not for
    its ``.data``.

    Args:
        out: the stage's output tensor, or None.
        deallocate_pipeline_outputs: when falsy, this function does nothing.

    Raises:
        AssertionError: if ``out`` is not a tensor, or is a view of another
            tensor (freeing a view would not release the underlying storage).
    '''
    if out is not None and deallocate_pipeline_outputs:
        assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
        assert out._base is None, "counter-productive to free a view of another tensor."
        # Same device and dtype; only the element count shrinks to 1.
        out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
liangjing's avatar
v1  
liangjing committed
126

127

128
def custom_backward(output, grad_output):
    '''Run backward by calling the C++ autograd engine directly.

    To make the 'deallocate_output_tensor' optimization work, the C++ autograd
    engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output (torch.Tensor): the pseudo-freed output tensor. It must contain
            exactly one element.
        grad_output (torch.Tensor or None): gradient with respect to the
            original (pre-deallocation) output. When None, a tensor of ones
            shaped like ``output`` is used.

    Raises:
        AssertionError: if ``output`` is not a tensor, ``grad_output`` is
            neither a tensor nor None, or ``output`` has more than one element.
    '''
    # Type checks come first so a wrong argument type yields the intended
    # assertion message instead of an AttributeError from ``.numel()``.
    assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), (
        "grad_output == '%s'." % type(grad_output).__name__
    )
    assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"

    # Implicit gradient for the (single-element, asserted above) output.
    if grad_output is None:
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
158
159


xingjinliang's avatar
xingjinliang committed
160
161
162
163
164
165
166
167
168
169
170
171
def set_current_microbatch(model, microbatch_id):
    """Set ``current_microbatch`` on the model's decoder, if there is one.

    The decoder is looked up through ``get_attr_wrapped_model``. If that lookup
    raises RuntimeError, or yields None, the call is a no-op.
    """
    try:
        decoder = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        # The (possibly wrapped) model exposes no decoder: nothing to record.
        return
    if decoder is not None:
        decoder.current_microbatch = microbatch_id


liangjing's avatar
v1  
liangjing committed
172
173
174
175
176
177
178
179
180
181
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    """Run the user's forward step on one microbatch and post-process the output.

    If it is the first stage, the input tensor is obtained from the
    data_iterator. Otherwise, the passed-in input_tensor is given to the model
    via its ``set_input_tensor`` hook.

    Args:
        forward_step_func (callable):
            The user forward function, called as
            ``forward_step_func(data_iterator, model)``. When
            ``checkpoint_activations_microbatch`` is not None, it is passed as
            a third positional argument. The function must return a pair:

                1. The forward output object. This is a tensor or some
                   collection of tensors. Its only hard requirement is that it
                   must be acceptable as input to the second element.
                2. A function that (optionally) reduces that output on the
                   last pipeline stage. It may compute a loss, reformat the
                   output, or just pass it through. Its return value selects
                   one of these paths:

                    a. ``(loss, data)``: ``loss`` is divided in place by
                       ``num_microbatches`` (legacy behaviour), so the loss is
                       stable as a function of how the step is split.
                       ``data`` is appended to ``forward_data_store``.
                    b. ``(loss, num_tokens, data)``: unless
                       ``config.calculate_per_token_loss`` is set, ``loss`` is
                       divided in place by ``num_tokens`` and then by
                       ``num_microbatches``. When that flag is set, ``loss``
                       is left unscaled here. ``data`` is appended to
                       ``forward_data_store``. This path is useful if the user
                       is not already averaging over tokens.
                    c. Arbitrary data (e.g. a dict or list of tensors for
                       inference). To select this path pass
                       ``collect_non_loss_data=True``; the function is then
                       called with ``non_loss_data=True`` and its full result
                       is stored. You may also want ``forward_only=True`` in
                       the call to the parent forward_backward function.
        data_iterator (iterator):
            The data iterator, passed through to ``forward_step_func``.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches. Used to scale the loss and the MoE
            auxiliary-loss scale.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            Receives one entry per call on the last pipeline stage: ``data``
            for paths (a) and (b), or the whole reduction result for path (c).
        config (object):
            The model configuration. Reads ``timers``, ``enable_autocast``,
            ``autocast_dtype``, ``calculate_per_token_loss`` and
            ``grad_scale_func``, plus ``num_moe_experts`` when present.
        collect_non_loss_data (bool, optional):
            Select path (c) above. Use it to collect arbitrary output from the
            model forward, e.g. for inference. Defaults to False.
        checkpoint_activations_microbatch (optional):
            Forwarded as the third argument of ``forward_step_func`` when not
            None. Defaults to None.
        is_first_microbatch (bool, optional):
            When True and the model has ``set_is_first_microbatch``, that
            method is called before the forward. Defaults to False.
        current_microbatch (int, optional):
            When not None, recorded on the model's decoder via
            ``set_current_microbatch``. Defaults to None.
        encoder_decoder_xattn (bool, optional):
            When True for an encoder-and-decoder model inside the decoder
            stack, the last input tensor is also returned so it can be sent
            downstream. Defaults to False.

    Returns:
        tuple: ``(output, num_tokens)``.

            * ``output`` is the (possibly scaled) forward output. It is
              unwrapped when ``input_tensor`` was not a list, and a one-element
              list otherwise. In the encoder-decoder cross-attention case
              above it is ``[output, input_tensor[-1]]``.
            * ``num_tokens`` is a ``torch.int`` tensor. It is 0 unless the
              reduction function returned a triple on the last stage.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    # Work on a list internally; remember whether to unwrap the result.
    was_single_tensor = not isinstance(input_tensor, list)
    input_tensors = [input_tensor] if was_single_tensor else input_tensor

    get_attr_wrapped_model(model, "set_input_tensor")(input_tensors)

    autocast_ctx = (
        torch.autocast("cuda", dtype=config.autocast_dtype)
        if config.enable_autocast
        else contextlib.nullcontext()
    )
    step_args = (data_iterator, model)
    if checkpoint_activations_microbatch is not None:
        step_args = step_args + (checkpoint_activations_microbatch,)
    with autocast_ctx:
        output_tensor, loss_func = forward_step_func(*step_args)

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if collect_non_loss_data:
            forward_data_store.append(loss_func(output_tensor, non_loss_data=True))
        else:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # Legacy loss averaging: over the number of microbatches only.
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # The MoE auxiliary loss is back-propagated through a custom autograd trick,
    # so its scale has to be set explicitly. Use grad_scale_func when present,
    # otherwise 1.
    if getattr(config, 'num_moe_experts', None) is not None:
        if config.grad_scale_func is None:
            loss_scale = torch.tensor(1.0)
        else:
            loss_scale = config.grad_scale_func(torch.ones(1, device=output_tensor.device))
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    forward_encoder_state = (
        model_type == ModelType.encoder_and_decoder
        and encoder_decoder_xattn
        and parallel_state.is_inside_decoder()
    )
    if forward_encoder_state:
        return [output_tensor, input_tensors[-1]], num_tokens

    result = output_tensor if was_single_tensor else [output_tensor]
    return result, num_tokens
328
329


liangjing's avatar
v1  
liangjing committed
330
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        input_tensor: the stage's input tensor, a list of input tensors, or
            None (e.g. on the first stage). None entries are skipped.
        output_tensor: the stage's output tensor or a list of output tensors.
            Only element 0 is back-propagated. A list passed by the caller is
            never modified.
        output_tensor_grad: gradient w.r.t. the output (tensor, list, or None).
            When element 0 is None and ``config.grad_scale_func`` is set, the
            output is scaled by that function before backward.
        model_type: the model's ModelType. It is used only for the
            encoder-and-decoder skip connection.
        config: reads ``timers``, ``grad_scale_func`` and
            ``deallocate_pipeline_outputs``.

    Returns:
        Gradient of loss with respect to the input tensor(s). If
        ``input_tensor`` was not a list, this is a single tensor, or None when
        the input was None. Otherwise it is a list aligned with
        ``input_tensor``.
    """

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Normalise the inputs to a list, remembering whether to unwrap the result.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    # Retain the grad on the input tensors so x.grad is populated below.
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    # Copy list arguments: the loss scaling below replaces element 0 and must
    # not leak into the caller's list.
    output_tensor = list(output_tensor) if isinstance(output_tensor, list) else [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input tensors (None entries stay None).
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder). The cheap length check comes first.
    if (
        len(output_tensor_grad) > 1  # excludes models that lack a skip connection.
        and model_type == ModelType.encoder_and_decoder
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
    ):
        if output_tensor_grad[1] is not None:
            assert input_tensor_grad[-1] is not None
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad


xingjinliang's avatar
xingjinliang committed
398
399
400
401
402
403
404
405
def check_first_val_step(first_val_step, forward_only, cond):
    """Combine ``cond`` with the first-validation-step flag.

    Outside forward-only runs, or when ``first_val_step`` is None, ``cond`` is
    returned unchanged. Otherwise the result is ``first_val_step and cond``.
    """
    if first_val_step is None or not forward_only:
        return cond
    return first_val_step and cond


liangjing's avatar
v1  
liangjing committed
406
407
408
409
410
411
412
413
414
415
416
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: int = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run forward and backward passes with no pipeline parallelism.

    There is no inter-stage communication. All microbatches except the last
    run inside ``config.no_sync_func`` (when provided). The last one runs
    outside it so that gradients are synchronized.

    Returns:
        list: ``forward_data_store``, the per-microbatch data collected by
        ``forward_step``.

    See get_forward_backward_func() for argument details.
    """
    if isinstance(model, list):
        assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    no_sync_func = (
        contextlib.nullcontext if config.no_sync_func is None else config.no_sync_func
    )

    model_type = get_model_type(model)

    forward_data_store = []
    # With a single stage there is no received activation and no incoming grad.
    input_tensor, output_tensor_grad = None, None
    total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

    def run_forward(microbatch_id):
        # One forward_step call. check_first_val_step may override the
        # first-microbatch flag during forward-only validation.
        return forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            is_first_microbatch=check_first_val_step(
                first_val_step, forward_only, microbatch_id == 0
            ),
            current_microbatch=microbatch_id,
        )

    with no_sync_func():
        for microbatch_id in range(num_microbatches - 1):
            output_tensor, num_tokens = run_forward(microbatch_id)
            total_num_tokens += num_tokens.item()
            if not forward_only:
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor, num_tokens = run_forward(num_microbatches - 1)
    total_num_tokens += num_tokens.item()
    if not forward_only:
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    if config.finalize_model_grads_func is not None and not forward_only:
        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism and layernorm all-reduce for sequence parallelism).
        token_count = total_num_tokens if config.calculate_per_token_loss else None
        config.finalize_model_grads_func([model], token_count)

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
500
501


xingjinliang's avatar
xingjinliang committed
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
def clear_embedding_activation_buffer(config, model):
    """Empty the embedding activation buffer used for deferred embedding wgrad.

    Only acts on the last pipeline stage (ignoring virtual stages), and only
    when ``config.defer_embedding_wgrad_compute`` is set. The owning module is
    found via the ``post_process`` attribute of ``model``, or of its last chunk
    when ``model`` is a list.

    Returns:
        The module whose buffer was cleared, or None when nothing was done.
    """
    if not (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        return None

    last_chunk = model[-1] if isinstance(model, list) else model
    embedding_module = get_attr_wrapped_model(last_chunk, 'post_process', return_model_obj=True)

    # Need to ensure no stray activations exist in this buffer.
    embedding_module.embedding_activation_buffer.clear()
    return embedding_module


def finish_embedding_wgrad_compute(config, embedding_module):
    """Drain the deferred output-layer weight-gradient computation.

    This is a no-op unless this rank is the last pipeline stage (ignoring
    virtual stages) and ``config.defer_embedding_wgrad_compute`` is set.

    Args:
        config: model config; ``defer_embedding_wgrad_compute`` is read and the
            config is forwarded to ``drain_embedding_wgrad_compute``.
        embedding_module: the module returned by
            ``clear_embedding_activation_buffer``. It must expose
            ``embedding_activation_buffer``, ``grad_output_buffer``,
            ``share_embeddings_and_output_weights``, ``output_layer`` and
            ``shared_embedding_or_output_weight()``.
    """
    if not (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        return

    # Pick the weight the output layer actually multiplies by. When embeddings
    # and output weights are tied, that is the shared weight. Otherwise it is
    # the output layer's own weight. (The previous code had these branches
    # swapped.)
    # NOTE(review): assumes, per LanguageModule, that a tied output layer on a
    # stage that also does pre-processing owns no weight of its own; confirm.
    if embedding_module.share_embeddings_and_output_weights:
        weight = embedding_module.shared_embedding_or_output_weight()
    else:
        weight = embedding_module.output_layer.weight

    drain_embedding_wgrad_compute(
        config,
        embedding_module.embedding_activation_buffer,
        embedding_module.grad_output_buffer,
        weight,
    )


liangjing's avatar
v1  
liangjing committed
543
544
545
546
547
548
549
550
551
552
553
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
xingjinliang's avatar
xingjinliang committed
554
    first_val_step: bool = None,
liangjing's avatar
v1  
liangjing committed
555
):
556
557
558
559
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
xingjinliang's avatar
xingjinliang committed
560
561
562
563
564
565
566
567
568
569

    # Convention used in this function:
    # num_microbatches for number of microbatches per pipeline stage;
    # num_model_chunks for virtual pipeline size;
    # then total_num_microbatches = num_microbatches * num_model_chunks.
    # Their corresponding index variables are
    # microbatch_id in [0, num_microbatches)
    # model_chunk_id in [0, num_model_chunks)
    # virtual_microbatch_id in [0, total_num_microbatches)

liangjing's avatar
v1  
liangjing committed
570
571
572
573
574
575
576
577
578
579
    assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
    assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"

    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

xingjinliang's avatar
xingjinliang committed
580
581
582
583
584
585
586
    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

liangjing's avatar
v1  
liangjing committed
587
588
    # Disable async grad reductions
    no_sync_func = config.no_sync_func
xingjinliang's avatar
xingjinliang committed
589
    if isinstance(no_sync_func, list):
liangjing's avatar
v1  
liangjing committed
590
591
592

        def multi_no_sync():
            stack = contextlib.ExitStack()
xingjinliang's avatar
xingjinliang committed
593
594
            for model_chunk_no_sync_func in config.no_sync_func:
                stack.enter_context(model_chunk_no_sync_func())
liangjing's avatar
v1  
liangjing committed
595
596
597
598
599
600
601
            return stack

        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

xingjinliang's avatar
xingjinliang committed
602
603
604
605
606
607
608
609
610
611
612
613
614
    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
        config.grad_sync_func = [config.grad_sync_func for _ in model]

    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
        config.param_sync_func = [config.param_sync_func for _ in model]

    # Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
    # They will be re-enabled at the end of this function.
    grad_sync_func, param_sync_func = None, None
    if forward_only:
        grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
        config.grad_sync_func, config.param_sync_func = None, None

liangjing's avatar
v1  
liangjing committed
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Model chunk IDs with synchronized grads
    synchronized_model_chunks = set()
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
633

634
635
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
xingjinliang's avatar
xingjinliang committed
636
637
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

638
    forward_data_store = []
639
640
641
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

642
643
644
    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()

xingjinliang's avatar
xingjinliang committed
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
    if (
        config.microbatch_group_size_per_vp_stage > num_microbatches
        or config.microbatch_group_size_per_vp_stage < pipeline_parallel_size
    ):
        msg = (
            'The number of contiguous micro-batches in a virtual pipeline stage'
            f'should range in [PP={pipeline_parallel_size} , M={num_microbatches}]'
        )
        raise ValueError(msg)

    # If the final micro-batch group has fewer micro-batches than pipeline-parallel size,
    # the pipeline will have dependency bubbles.
    final_microbatch_group_size = num_microbatches % config.microbatch_group_size_per_vp_stage
    if 0 < final_microbatch_group_size < pipeline_parallel_size:
        msg = 'The remainder of M (the total micro-batches) divided by N (number of '
        msg += 'contiguous micro-batches in a virtual pipeline stage) should be 0, '
        msg += 'or larger than or equal to the pipeline-parallel size, but it is '
        msg += f'{final_microbatch_group_size}. '
        msg += 'Otherwise, it introduces dependency bubbles in the pipeline '
        msg += 'and reduces throughput.'
665
666
667
668
669
670
        raise RuntimeError(msg)

    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

liangjing's avatar
v1  
liangjing committed
671
672
673
    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
674
        )
675

liangjing's avatar
v1  
liangjing committed
676
    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
xingjinliang's avatar
xingjinliang committed
677
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
liangjing's avatar
v1  
liangjing committed
678
679
680
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

681
682
    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
683
    total_num_microbatches = num_microbatches * num_model_chunks
684
685
    all_warmup_microbatches = False
    if forward_only:
686
        num_warmup_microbatches = total_num_microbatches
687
    else:
xingjinliang's avatar
xingjinliang committed
688
        # Run (num_model_chunks-1)*config.microbatch_group_size_per_vp_stage on
689
690
691
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
xingjinliang's avatar
xingjinliang committed
692
693
694
695
696
        num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
        num_warmup_microbatches += (
            num_model_chunks - 1
        ) * config.microbatch_group_size_per_vp_stage
        if num_warmup_microbatches >= total_num_microbatches:
697
            num_warmup_microbatches = total_num_microbatches
698
            all_warmup_microbatches = True
liangjing's avatar
v1  
liangjing committed
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
xingjinliang's avatar
xingjinliang committed
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
        config.param_sync_func[0](model[0].parameters())
        config.param_sync_func[1](model[1].parameters())

    # Create a tunable schedule lookup table.
    # The schedule lookup table uses the virtual_microbatch_id to find the corresponding
    # microbatch_id and model_chunk_id. For example, the tunable schedule table for
    # PP2 N3M5 with VP2 is constructed as below:
    # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    # microbatch_id         | 0 1 2 0 1 2 3 4 3 4
    # model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
    schedule_table = []
    for min_microbatch_id_in_group in range(
        0, num_microbatches, config.microbatch_group_size_per_vp_stage
    ):
        if (
            min_microbatch_id_in_group + config.microbatch_group_size_per_vp_stage
            >= num_microbatches
        ):
            # Construct schedule for the last microbatch group
            schedule_table.extend(
                [
                    (microbatch_id, model_chunk_id)
                    for model_chunk_id in range(len(model))
                    for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
                ]
            )
        else:
            # Construct schedule for other microbatch groups
            schedule_table.extend(
                [
                    (microbatch_id, model_chunk_id)
                    for model_chunk_id in range(len(model))
                    for microbatch_id in range(
                        min_microbatch_id_in_group,
                        min_microbatch_id_in_group + config.microbatch_group_size_per_vp_stage,
                    )
                ]
            )
753

xingjinliang's avatar
xingjinliang committed
754
755
756
757
758
759
760
761
762
763
764
    # Decouple individual lookup table for microbatch_id and model_chunk_id.
    # For example, the micro-batch table for PP2 N3M5 with VP2 is
    # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    # microbatch_id         | 0 1 2 0 1 2 3 4 3 4
    # Similarly, the model chunk table is
    # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    # model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
    # Both tables are indexed with virtual_microbatch_id.
    microbatch_id_table, model_chunk_id_table = zip(*schedule_table)

    def get_model_chunk_id(virtual_microbatch_id, forward):
765
        """Helper method to get the model chunk ID given the iteration number."""
xingjinliang's avatar
xingjinliang committed
766
        model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches]
767
        if not forward:
liangjing's avatar
v1  
liangjing committed
768
            model_chunk_id = num_model_chunks - model_chunk_id - 1
769
        return model_chunk_id
770

xingjinliang's avatar
xingjinliang committed
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
    def get_microbatch_id_in_model_chunk(iteration_id, forward):
        """Helper method to get the microbatch_id within model chunk given the iteration number."""
        assert forward
        microbatch_id_in_model_chunk = microbatch_id_table[iteration_id]
        return microbatch_id_in_model_chunk

    def num_released_microbatches(virtual_microbatch_id, model_chunk_id):
        """Helper method to count number of released (i.e. popped from input_tensors)
        microbatches for a model chunk."""
        if forward_only:  # Micro-batch is released after forward prop.
            return model_chunk_id_table[:virtual_microbatch_id].count(model_chunk_id)
        else:  # Micro-batch is released after backward prop.
            # Zero backward prop in warmup.
            if virtual_microbatch_id < num_warmup_microbatches:
                return 0
            else:
                backward_microbatch_id = virtual_microbatch_id - num_warmup_microbatches
                model_chunk_id = num_model_chunks - model_chunk_id - 1
                return model_chunk_id_table[:backward_microbatch_id].count(model_chunk_id)

    def is_first_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool:
liangjing's avatar
v1  
liangjing committed
792
        """Check if an iteration is the first for a model chunk."""
xingjinliang's avatar
xingjinliang committed
793
794
        if virtual_microbatch_id < total_num_microbatches:
            return microbatch_id_table[virtual_microbatch_id] == 0
liangjing's avatar
v1  
liangjing committed
795
796
797
        else:
            return False

xingjinliang's avatar
xingjinliang committed
798
    def is_last_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool:
liangjing's avatar
v1  
liangjing committed
799
        """Check if an iteration is the last for a model chunk."""
xingjinliang's avatar
xingjinliang committed
800
801
        if virtual_microbatch_id < total_num_microbatches:
            return microbatch_id_table[virtual_microbatch_id] == num_microbatches - 1
liangjing's avatar
v1  
liangjing committed
802
803
804
        else:
            return False

xingjinliang's avatar
xingjinliang committed
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
    def recv_tensor_from_previous_stage(virtual_microbatch_id, forward):
        """Determine if peers are sending, and where in data structure
        to put received tensors.
        Return a boolean if the pipeline stage expects to recv from peers, and the
        corresponding model_chunk_id for the received tensor.
        """
        recv = True
        # The leading pipeline stage is the first rank in fwd and the last rank in bwd.
        is_leading_pipeline_stage = (
            parallel_state.is_pipeline_first_stage(ignore_virtual=True)
            if forward
            else parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        )

        last_model_chunk = (num_model_chunks - 1) if forward else 0

        if is_leading_pipeline_stage:
            # The leading pipeline stage is ahead of the ending pipeline stage
            # (i.e. last rank in fwd and first rank in bwd) by (pipeline_parallel_size - 1).
            # Let's consider bwd as an example with PP 4:
            #       0 1 2 3 ...
            #     0 1 2 3 ...
            #   0 1 2 3 ...
            # 0 1 2 3 ...
            if virtual_microbatch_id < (pipeline_parallel_size - 1):
                # The ending stage has not produced any tensors, so no recv will be initiated.
                recv = False
                next_model_chunk_id = get_model_chunk_id(virtual_microbatch_id + 1, forward)
            else:
                # Find the model chunk of the aligned microbatches in the ending stage.
                # For example, microbatch 0 in the ending stage is aligned with microbatch 3
                # in the leading stage.
                next_model_chunk_id = get_model_chunk_id(
                    virtual_microbatch_id - (pipeline_parallel_size - 1), forward
                )
            # Last model chunk in the final stage does not produce tensors.
            if next_model_chunk_id == last_model_chunk:
                recv = False
            if forward:
                # Model chunk id increases in forward.
                next_model_chunk_id += 1
            else:
                # Model chunk id decreases in backward.
                next_model_chunk_id -= 1
        else:
            next_model_chunk_id = get_model_chunk_id(virtual_microbatch_id + 1, forward)

        return recv, next_model_chunk_id

    def forward_step_helper(
        virtual_microbatch_id, microbatch_id, checkpoint_activations_microbatch
    ):
857
858
859
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
xingjinliang's avatar
xingjinliang committed
860
        model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=True)
861
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
862

liangjing's avatar
v1  
liangjing committed
863
864
865
866
867
868
        # launch param synchronization for next model chunk
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.param_sync_func is not None:
xingjinliang's avatar
xingjinliang committed
869
            param_sync_virtual_microbatch_id = virtual_microbatch_id + pipeline_parallel_rank
liangjing's avatar
v1  
liangjing committed
870
            if (
xingjinliang's avatar
xingjinliang committed
871
872
                param_sync_virtual_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_virtual_microbatch_id)
liangjing's avatar
v1  
liangjing committed
873
            ):
xingjinliang's avatar
xingjinliang committed
874
875
876
                param_sync_chunk_id = (
                    get_model_chunk_id(param_sync_virtual_microbatch_id, forward=True) + 1
                )
liangjing's avatar
v1  
liangjing committed
877
                if 1 < param_sync_chunk_id < num_model_chunks:
xingjinliang's avatar
xingjinliang committed
878
879
880
                    config.param_sync_func[param_sync_chunk_id](
                        model[param_sync_chunk_id].parameters()
                    )
liangjing's avatar
v1  
liangjing committed
881

882
        # forward step
883
        if parallel_state.is_pipeline_first_stage():
liangjing's avatar
v1  
liangjing committed
884
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
885
                input_tensors[model_chunk_id].append(None)
xingjinliang's avatar
xingjinliang committed
886
887
888
889
890
891
892
893
894
895

        # For non-depth-first pipeline schedules, the first rank would buffer multiple received
        # activation tensors for a model chunk until accessed during warmup.
        # This input buffering is needed to overlap the computation with the receipt of
        # the next inputs. To index the proper buffered inputs for forword_step, we use
        # microbatch_id offset with number of released microbatches that have completed backprop.
        offset = num_released_microbatches(virtual_microbatch_id, model_chunk_id)
        input_tensor = input_tensors[model_chunk_id][microbatch_id - offset]

        output_tensor, num_tokens = forward_step(
liangjing's avatar
v1  
liangjing committed
896
897
898
899
900
901
902
903
904
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
xingjinliang's avatar
xingjinliang committed
905
906
907
908
909
910
            check_first_val_step(
                first_val_step,
                forward_only,
                is_first_microbatch_for_model_chunk(virtual_microbatch_id),
            ),
            current_microbatch=microbatch_id,
liangjing's avatar
v1  
liangjing committed
911
        )
xingjinliang's avatar
xingjinliang committed
912

913
914
        output_tensors[model_chunk_id].append(output_tensor)

xingjinliang's avatar
xingjinliang committed
915
916
917
918
        nonlocal total_num_tokens
        total_num_tokens += num_tokens.item()

        # If forward-only, no need to save tensors for a backward pass.
919
        if forward_only:
xingjinliang's avatar
xingjinliang committed
920
921
            # Release the tensor that have completed forward step.
            input_tensors[model_chunk_id].pop(0)
922
923
            output_tensors[model_chunk_id].pop()

924
925
        return output_tensor

xingjinliang's avatar
xingjinliang committed
926
    def backward_step_helper(virtual_microbatch_id):
927
928
929
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
xingjinliang's avatar
xingjinliang committed
930
        model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
931
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
932

liangjing's avatar
v1  
liangjing committed
933
        # launch grad synchronization (default)
xingjinliang's avatar
xingjinliang committed
934
935
936
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(
            virtual_microbatch_id
        ):
liangjing's avatar
v1  
liangjing committed
937
938
939
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)

940
        if parallel_state.is_pipeline_last_stage():
941
942
943
944
945
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
xingjinliang's avatar
xingjinliang committed
946

liangjing's avatar
v1  
liangjing committed
947
948
949
950
951
952
953
954
955
956
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )

        # launch grad synchronization (custom grad sync)
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.grad_sync_func is not None:
xingjinliang's avatar
xingjinliang committed
957
958
959
            grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
            if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                grad_sync_virtual_microbatch_id
liangjing's avatar
v1  
liangjing committed
960
            ):
xingjinliang's avatar
xingjinliang committed
961
962
963
                grad_sync_chunk_id = get_model_chunk_id(
                    grad_sync_virtual_microbatch_id, forward=False
                )
liangjing's avatar
v1  
liangjing committed
964
                enable_grad_sync()
xingjinliang's avatar
xingjinliang committed
965
                config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
liangjing's avatar
v1  
liangjing committed
966
967
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()
968
969
970
971

        return input_tensor_grad

    # Run warmup forward passes.
972
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
liangjing's avatar
v1  
liangjing committed
973
974
975
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
xingjinliang's avatar
xingjinliang committed
976
    fwd_wait_recv_handles = None
liangjing's avatar
v1  
liangjing committed
977
    bwd_wait_handles = None
xingjinliang's avatar
xingjinliang committed
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
    bwd_wait_recv_handles = None
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        fwd_recv_buffer_size = (
            config.microbatch_group_size_per_vp_stage - pipeline_parallel_size + 1
        )
    else:
        fwd_recv_buffer_size = 1
    if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        bwd_recv_buffer_size = (
            config.microbatch_group_size_per_vp_stage - pipeline_parallel_size + 1
        )
    else:
        bwd_recv_buffer_size = 1
    fwd_recv_buffer = [None] * fwd_recv_buffer_size
    bwd_recv_buffer = [None] * bwd_recv_buffer_size
    recv_prev_wait_handles = []
    send_next_wait_handle = None
    send_prev_wait_handle = None
    recv_next_wait_handles = []
liangjing's avatar
v1  
liangjing committed
997

998
    for k in range(num_warmup_microbatches):
xingjinliang's avatar
xingjinliang committed
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
        cur_model_chunk_id = get_model_chunk_id(k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)

        if config.overlap_p2p_comm_warmup_flush:
            if not parallel_state.is_pipeline_first_stage() and k != 0:
                assert recv_prev_wait_handles, (
                    f'pp rank {pipeline_parallel_rank}, iteration {k},'
                    'should have registered recv handle'
                )
                recv_prev_wait_handle = recv_prev_wait_handles.pop(0)
                recv_prev_wait_handle.wait()
liangjing's avatar
v1  
liangjing committed
1010

xingjinliang's avatar
xingjinliang committed
1011
1012
        # Determine if tensor should be received from previous stage.
        recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(k, forward=True)
liangjing's avatar
v1  
liangjing committed
1013

xingjinliang's avatar
xingjinliang committed
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
        # No receive in last iteration when recv iteration k+1.
        if k == (total_num_microbatches - 1):
            recv_prev = False

        # Prefetch recv for iteration k+1 for non-first ranks.
        if config.overlap_p2p_comm_warmup_flush and not parallel_state.is_pipeline_first_stage(
            ignore_virtual=True
        ):
            fwd_recv_buffer[k % fwd_recv_buffer_size], fwd_wait_recv_handles = (
                p2p_communication.send_forward_recv_forward(
                    output_tensor=None,  # No output_tensor to send.
                    recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
            )

            if fwd_wait_recv_handles:
                recv_prev_wait_handles.append(fwd_wait_recv_handles.pop("recv_prev"))

        # Decide to checkpoint all layers' activations of the current micro-batch.
liangjing's avatar
v1  
liangjing committed
1036
1037
1038
1039
1040
1041
1042
1043
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

xingjinliang's avatar
xingjinliang committed
1044
1045
        microbatch_id = get_microbatch_id_in_model_chunk(k, forward=True)
        output_tensor = forward_step_helper(k, microbatch_id, checkpoint_activations_microbatch)
1046
1047

        # Don't send tensor downstream if on last stage.
1048
        if parallel_state.is_pipeline_last_stage():
1049
            output_tensor = None
1050
1051
1052

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
xingjinliang's avatar
xingjinliang committed
1053
        if not config.overlap_p2p_comm_warmup_flush:
liangjing's avatar
v1  
liangjing committed
1054
1055
            if (
                k == (num_warmup_microbatches - 1)
xingjinliang's avatar
xingjinliang committed
1056
                and not config.overlap_p2p_comm
liangjing's avatar
v1  
liangjing committed
1057
1058
1059
1060
1061
1062
1063
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
xingjinliang's avatar
xingjinliang committed
1064
1065
1066
1067
1068
1069
1070
1071
1072
                (input_tensor, output_tensor_grad) = (
                    p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor,
                        input_tensor_grad,
                        recv_prev=recv_prev,
                        recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        config=config,
                    )
liangjing's avatar
v1  
liangjing committed
1073
1074
1075
1076
1077
1078
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
xingjinliang's avatar
xingjinliang committed
1079
1080
1081
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
1082
        else:
xingjinliang's avatar
xingjinliang committed
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
            if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # Send only since recv prefetched.
                _, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                    output_tensor,
                    recv_prev=False,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
            else:  # No prefetch for first rank, so both send and recv initiated.
                fwd_recv_buffer[k % fwd_recv_buffer_size], fwd_wait_handles = (
                    p2p_communication.send_forward_recv_forward(
                        output_tensor,
                        recv_prev=recv_prev,
                        tensor_shape=tensor_shape,
                        config=config,
                        overlap_p2p_comm=True,
                    )
                )
            if send_next_wait_handle is not None:
                send_next_wait_handle.wait()
            if fwd_wait_handles is not None:
                send_next_wait_handle = (
                    fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None
                )
                if "recv_prev" in fwd_wait_handles:
                    recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(
                    fwd_recv_buffer[k % fwd_recv_buffer_size]
                )
                fwd_recv_buffer[(k + 1) % fwd_recv_buffer_size] = None
liangjing's avatar
v1  
liangjing committed
1117

xingjinliang's avatar
xingjinliang committed
1118
        if config.overlap_p2p_comm:
liangjing's avatar
v1  
liangjing committed
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

xingjinliang's avatar
xingjinliang committed
1129
1130
1131
1132
1133
1134
1135
1136
                (bwd_recv_buffer[-1], bwd_wait_handles) = (
                    p2p_communication.send_backward_recv_backward(
                        input_tensor_grad,
                        recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        config=config,
                        overlap_p2p_comm=True,
                    )
liangjing's avatar
v1  
liangjing committed
1137
                )
xingjinliang's avatar
xingjinliang committed
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
                if send_prev_wait_handle is not None:
                    send_prev_wait_handle.wait()
                if bwd_wait_handles is not None:
                    send_prev_wait_handle = (
                        bwd_wait_handles.pop("send_prev")
                        if "send_prev" in bwd_wait_handles
                        else None
                    )
                    if "recv_next" in bwd_wait_handles:
                        recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next"))

                if recv_next:
                    output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
1151
1152
1153
1154
1155
1156

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

xingjinliang's avatar
xingjinliang committed
1157
        # Decide to checkpoint all layers' activations of the current micro-batch.
liangjing's avatar
v1  
liangjing committed
1158
1159
1160
1161
1162
1163
1164
1165
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

xingjinliang's avatar
xingjinliang committed
1166
1167
1168
        cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
        microbatch_id = get_microbatch_id_in_model_chunk(forward_k, forward=True)
liangjing's avatar
v1  
liangjing committed
1169
        if config.overlap_p2p_comm:
xingjinliang's avatar
xingjinliang committed
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
            if not parallel_state.is_pipeline_first_stage():
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_prev_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
                        'should have registered recv handle'
                    )
                    recv_prev_wait_handle = recv_prev_wait_handles.pop(0)
                    recv_prev_wait_handle.wait()
                else:
                    if recv_prev_wait_handles is not None and recv_prev_wait_handles:
                        recv_prev_wait_handle = recv_prev_wait_handles.pop(0)
                        recv_prev_wait_handle.wait()
liangjing's avatar
v1  
liangjing committed
1182
1183
1184

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

xingjinliang's avatar
xingjinliang committed
1185
1186
1187
            output_tensor = forward_step_helper(
                forward_k, microbatch_id, checkpoint_activations_microbatch
            )
liangjing's avatar
v1  
liangjing committed
1188
1189
1190
1191
1192
1193

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

xingjinliang's avatar
xingjinliang committed
1194
            # Last virtual stage no activation tensor to send.
liangjing's avatar
v1  
liangjing committed
1195
1196
1197
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

xingjinliang's avatar
xingjinliang committed
1198
1199
1200
            recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
                forward_k, forward=True
            )
1201

liangjing's avatar
v1  
liangjing committed
1202
1203
1204
1205
            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False
1206

liangjing's avatar
v1  
liangjing committed
1207
1208
            # Send activation tensor to the next stage and receive activation tensor from the
            # previous stage
xingjinliang's avatar
xingjinliang committed
1209
1210
1211
1212
1213
1214
1215
1216
            fwd_recv_buffer[forward_k % fwd_recv_buffer_size], fwd_wait_handles = (
                p2p_communication.send_forward_recv_forward(
                    output_tensor,
                    recv_prev=recv_prev,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
liangjing's avatar
v1  
liangjing committed
1217
            )
xingjinliang's avatar
xingjinliang committed
1218
1219
1220
1221
1222
1223
1224
1225
            if send_next_wait_handle is not None:
                send_next_wait_handle.wait()
            if fwd_wait_handles is not None:
                send_next_wait_handle = (
                    fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None
                )
                if "recv_prev" in fwd_wait_handles:
                    recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
liangjing's avatar
v1  
liangjing committed
1226
1227
1228
1229
1230
1231
            # assert fwd_wait_handles is not None

            # Backward pass.
            backward_k = k
            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
xingjinliang's avatar
xingjinliang committed
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
            if not parallel_state.is_pipeline_last_stage():
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_next_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
                        'should have registered recv next handle'
                    )
                    recv_next_wait_handle = recv_next_wait_handles.pop(0)
                    recv_next_wait_handle.wait()
                else:
                    if recv_next_wait_handles is not None and recv_next_wait_handles:
                        recv_next_wait_handle = recv_next_wait_handles.pop(0)
                        recv_next_wait_handle.wait()

            input_tensor_grad = backward_step_helper(backward_k)
liangjing's avatar
v1  
liangjing committed
1246

xingjinliang's avatar
xingjinliang committed
1247
            # First virtual stage no activation gradient tensor to send.
liangjing's avatar
v1  
liangjing committed
1248
1249
1250
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

xingjinliang's avatar
xingjinliang committed
1251
1252
            recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
                backward_k, forward=False
liangjing's avatar
v1  
liangjing committed
1253
1254
            )

xingjinliang's avatar
xingjinliang committed
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
            (bwd_recv_buffer[backward_k % bwd_recv_buffer_size], bwd_wait_handles) = (
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
            )
            if send_prev_wait_handle is not None:
                send_prev_wait_handle.wait()
            if bwd_wait_handles is not None:
                send_prev_wait_handle = (
                    bwd_wait_handles.pop("send_prev") if "send_prev" in bwd_wait_handles else None
                )
                if "recv_next" in bwd_wait_handles:
                    recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next"))

            # Put input_tensor and output_tensor_grad in data structures in the
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(
                    fwd_recv_buffer[forward_k % fwd_recv_buffer_size]
                )
                fwd_recv_buffer[(forward_k + 1) % fwd_recv_buffer_size] = None
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(
                    bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
                )
                bwd_recv_buffer[(backward_k + 1) % bwd_recv_buffer_size] = None
        else:  # No p2p overlap.
            output_tensor = forward_step_helper(
                forward_k, microbatch_id, checkpoint_activations_microbatch
            )
liangjing's avatar
v1  
liangjing committed
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

xingjinliang's avatar
xingjinliang committed
1309
1310
1311
            recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
                forward_k, forward=True
            )
1312

xingjinliang's avatar
xingjinliang committed
1313
1314
1315
            recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
                backward_k, forward=False
            )
1316

liangjing's avatar
v1  
liangjing committed
1317
1318
1319
            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
1320
1321
                recv_prev = False

liangjing's avatar
v1  
liangjing committed
1322
            # Communicate tensors.
xingjinliang's avatar
xingjinliang committed
1323
1324
1325
1326
1327
1328
1329
1330
1331
            (input_tensor, output_tensor_grad) = (
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
liangjing's avatar
v1  
liangjing committed
1332
1333
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
1334

xingjinliang's avatar
xingjinliang committed
1335
1336
1337
1338
1339
1340
            # Put input_tensor and output_tensor_grad in data structures in the
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
liangjing's avatar
v1  
liangjing committed
1341
1342

    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
1343

1344
    # Run cooldown backward passes (flush out pipeline).
1345
    if not forward_only:
xingjinliang's avatar
xingjinliang committed
1346
1347
1348
        if bwd_wait_handles is not None:
            for bwd_wait_handle in bwd_wait_handles.values():
                bwd_wait_handle.wait()
liangjing's avatar
v1  
liangjing committed
1349

1350
        if all_warmup_microbatches:
liangjing's avatar
v1  
liangjing committed
1351
1352
1353
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
1354
        for k in range(num_microbatches_remaining, total_num_microbatches):
xingjinliang's avatar
xingjinliang committed
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
            cur_model_chunk_id = get_model_chunk_id(k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
            if not parallel_state.is_pipeline_last_stage() and k != 0:
                if config.overlap_p2p_comm_warmup_flush:
                    assert recv_next_wait_handles, (
                        f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
                        'should have registered recv next handle'
                    )
                    recv_next_wait_handle = recv_next_wait_handles.pop(0)
                    recv_next_wait_handle.wait()
                else:
                    if recv_next_wait_handles is not None and recv_next_wait_handles:
                        recv_next_wait_handle = recv_next_wait_handles.pop(0)
                        recv_next_wait_handle.wait()

            recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
                k, forward=False
            )

1374
            if k == (total_num_microbatches - 1):
1375
                recv_next = False
xingjinliang's avatar
xingjinliang committed
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437

            # Prefetch recv for backward iteration k+1 for non last ranks.
            if config.overlap_p2p_comm_warmup_flush and not parallel_state.is_pipeline_last_stage(
                ignore_virtual=True
            ):
                bwd_recv_buffer[k % bwd_recv_buffer_size], bwd_wait_recv_handles = (
                    p2p_communication.send_backward_recv_backward(
                        input_tensor_grad=None,  # No input_tensor_grad to send.
                        recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        config=config,
                        overlap_p2p_comm=True,
                    )
                )

                if bwd_wait_recv_handles:
                    recv_next_wait_handles.append(bwd_wait_recv_handles.pop("recv_next"))

            input_tensor_grad = backward_step_helper(k)

            # First virtual stage no activation gradient tensor to send.
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            if config.overlap_p2p_comm_warmup_flush:
                if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    _, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
                        input_tensor_grad,
                        recv_next=False,
                        tensor_shape=tensor_shape,
                        config=config,
                        overlap_p2p_comm=True,
                    )
                else:
                    bwd_recv_buffer[k % bwd_recv_buffer_size], bwd_wait_handles = (
                        p2p_communication.send_backward_recv_backward(
                            input_tensor_grad,
                            recv_next=recv_next,
                            tensor_shape=tensor_shape,
                            config=config,
                            overlap_p2p_comm=True,
                        )
                    )

                if send_prev_wait_handle is not None:
                    send_prev_wait_handle.wait()
                if bwd_wait_handles is not None:
                    send_prev_wait_handle = (
                        bwd_wait_handles.pop("send_prev")
                        if "send_prev" in bwd_wait_handles
                        else None
                    )
                    if "recv_next" in bwd_wait_handles:
                        recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next"))
                if recv_next:
                    output_tensor_grads[next_backward_model_chunk_id].append(
                        bwd_recv_buffer[k % bwd_recv_buffer_size]
                    )
                    bwd_recv_buffer[(k + 1) % bwd_recv_buffer_size] = None

            else:
                output_tensor_grad = p2p_communication.send_backward_recv_backward(
liangjing's avatar
v1  
liangjing committed
1438
1439
1440
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )

xingjinliang's avatar
xingjinliang committed
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
                if recv_next:
                    output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

        if send_prev_wait_handle is not None:
            send_prev_wait_handle.wait()

        # Launch any remaining grad reductions.
        enable_grad_sync()
        if config.grad_sync_func is not None:
            for model_chunk_id in range(num_model_chunks):
                if model_chunk_id not in synchronized_model_chunks:
                    config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                    synchronized_model_chunks.add(model_chunk_id)

    assert (
        not recv_prev_wait_handles
    ), 'recv_prev_wait_handles should be cleared at the end of a step'
    assert (
        not recv_next_wait_handles
    ), 'recv_next_wait_handles should be cleared at the end of a step'

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    # Restore config.grad_sync_func and config.param_sync_func.
    if forward_only:
        config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func

    if config.timers is not None:
        config.timers('forward-backward').stop()
1481

1482
    return forward_data_store
1483

liangjing's avatar
v1  
liangjing committed
1484
1485
1486
1487
1488
1489
1490
1491
1492

def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
    encoder_decoder_xattn: bool,
):
    """Compute the shapes of the activation tensor(s) exchanged by ``rank``.

    Sequence lengths are first divided by the context-parallel world size.
    When ``config.sequence_parallel`` is set, they are further divided by the
    tensor-model-parallel world size. ``decoder_seq_length`` is only used
    (and only divided) for ``ModelType.encoder_and_decoder`` models.

    Every shape is ``(sequence_length, micro_batch_size, config.hidden_size)``.

    For ``ModelType.encoder_and_decoder``:
      * a rank inside the encoder but not inside the decoder gets a single
        shape built from ``seq_length``;
      * otherwise, when ``encoder_decoder_xattn`` is set, two shapes are
        returned: the decoder shape first, then the encoder shape (the decoder
        also needs the encoder output for cross-attention);
      * otherwise a single decoder shape is returned.

    For any other model type a single shape built from ``seq_length`` is
    returned.

    Returns:
        list of ``(seq, micro_batch, hidden)`` tuples.
    """
    is_encoder_and_decoder = model_type == ModelType.encoder_and_decoder

    # Sequence dimension is split across context-parallel ranks.
    cp_world_size = parallel_state.get_context_parallel_world_size()
    seq_len = seq_length // cp_world_size
    dec_len = decoder_seq_length // cp_world_size if is_encoder_and_decoder else decoder_seq_length

    # With sequence parallelism it is additionally split across tensor-parallel ranks.
    if config.sequence_parallel:
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        seq_len = seq_len // tp_world_size
        if is_encoder_and_decoder:
            dec_len = dec_len // tp_world_size

    def _shape(length):
        # One (seq, micro_batch, hidden) activation shape.
        return (length, micro_batch_size, config.hidden_size)

    if not is_encoder_and_decoder:
        # model_type == ModelType.encoder_or_decoder
        return [_shape(seq_len)]
    if parallel_state.is_inside_encoder(rank) and not parallel_state.is_inside_decoder(rank):
        return [_shape(seq_len)]
    if encoder_decoder_xattn:
        # Decoder activation first, encoder activation second.
        return [_shape(dec_len), _shape(seq_len)]
    return [_shape(dec_len)]


liangjing's avatar
v1  
liangjing committed
1530
def recv_forward(tensor_shapes, config):
    """Receive one activation tensor per entry of ``tensor_shapes``.

    Non-interleaved-schedule wrapper around ``p2p_communication.recv_forward``.
    A ``None`` shape means nothing is received for that slot, and ``None`` is
    placed at the same position in the returned list.
    """
    return [
        None if shape is None else p2p_communication.recv_forward(shape, config)
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
1541
def recv_backward(tensor_shapes, config):
    """Receive one output-gradient tensor per entry of ``tensor_shapes``.

    Non-interleaved-schedule wrapper around ``p2p_communication.recv_backward``.
    A ``None`` shape means nothing is received for that slot, and ``None`` is
    placed at the same position in the returned list.
    """
    return [
        None if shape is None else p2p_communication.recv_backward(shape, config)
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
1552
def send_forward(output_tensors, tensor_shapes, config):
    """Send each output activation whose paired shape is not ``None``.

    Non-interleaved-schedule wrapper around ``p2p_communication.send_forward``.
    A non-list ``output_tensors`` is treated as a single-element list. Tensors
    are paired with ``tensor_shapes`` positionally, and pairing stops at the
    shorter of the two sequences.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, config)
1560
1561


liangjing's avatar
v1  
liangjing committed
1562
def send_backward(input_tensor_grads, tensor_shapes, config):
    """Send each input gradient whose paired shape is not ``None``.

    Non-interleaved-schedule wrapper around ``p2p_communication.send_backward``.
    A non-list ``input_tensor_grads`` is treated as a single-element list.
    Gradients are paired with ``tensor_shapes`` positionally, and pairing stops
    at the shorter of the two sequences.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, config)
1570
1571


liangjing's avatar
v1  
liangjing committed
1572
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
    """Send activations forward and receive their gradients back, slot by slot.

    Non-interleaved-schedule wrapper around
    ``p2p_communication.send_forward_recv_backward``. A non-list
    ``output_tensors`` is treated as a single-element list. For every
    ``(tensor, shape)`` pair (positional, truncated to the shorter input), a
    ``None`` shape skips communication and yields ``None``. Otherwise the
    received gradient is stored at that position.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    return [
        None
        if shape is None
        else p2p_communication.send_forward_recv_backward(tensor, shape, config)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1589
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
    """Send gradients backward and receive the next activations, slot by slot.

    Non-interleaved-schedule wrapper around
    ``p2p_communication.send_backward_recv_forward``. A non-list
    ``input_tensor_grads`` is treated as a single-element list. For every
    ``(grad, shape)`` pair (positional, truncated to the shorter input), a
    ``None`` shape skips communication and yields ``None``. Otherwise the
    received activation is stored at that position.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    return [
        None
        if shape is None
        else p2p_communication.send_backward_recv_forward(grad, shape, config)
        for grad, shape in zip(grads, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run the non-interleaved 1F1B schedule, with p2p communication between
    pipeline stages.

    On each pipeline rank the schedule has three phases:
      1. warmup: ``pipeline_world_size - rank - 1`` forward passes (capped at
         ``num_microbatches``), whose outputs are sent to the next stage;
      2. steady state (1F1B): one forward then one backward per iteration for
         the remaining microbatches;
      3. cooldown: the backward passes still outstanding from warmup.
    When ``forward_only`` is True, all backward work and gradient handling
    is skipped.

    Args:
        forward_step_func: user forward-step function, forwarded to
            ``forward_step``.
        data_iterator: a single iterator, or a one-element list of iterators.
        model: a single module, or a one-element list (model chunking is not
            supported by this schedule).
        num_microbatches: number of microbatches to process.
        seq_length, micro_batch_size, decoder_seq_length: used by
            ``get_tensor_shapes`` to size the tensors exchanged with the
            neighbouring stages.
        forward_only: if True, run only forward passes.
        collect_non_loss_data: forwarded to ``forward_step``.
        first_val_step: forwarded (through ``check_first_val_step``) to
            ``forward_step`` for the first microbatch.

    Returns:
        list: ``forward_data_store``, the list that ``forward_step`` appends
        to. It presumably holds losses on the last stage and stays empty
        elsewhere (TODO confirm against ``forward_step``).

    Raises:
        ValueError: if ``config.overlap_p2p_comm`` is set.
    """

    if isinstance(model, list):
        assert (
            len(model) == 1
        ), "non-interleaved pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-interleaved pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.overlap_p2p_comm:
        raise ValueError(
            "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
        )

    # Needed only when gradients are finalized in M-Core. Pre-bound so the
    # name is always defined; it is only read under the same condition below.
    embedding_module = None
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)
    encoder_decoder_xattn = get_model_xattn(model)

    rank = parallel_state.get_pipeline_model_parallel_rank()
    # Shapes received from the previous stage (rank - 1) and sent from this one.
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=encoder_decoder_xattn,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=encoder_decoder_xattn,
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    # Token counter kept on the device. It is accumulated with tensor ops
    # (never .item()) so the loop does not force a host/device sync per
    # microbatch.
    total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = recv_forward(recv_tensor_shapes, config)
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
            encoder_decoder_xattn=encoder_decoder_xattn,
        )
        send_forward(output_tensor, send_tensor_shapes, config)
        # In-place device add; no blocking .item() sync.
        total_num_tokens += num_tokens

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, config)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
            ),
            current_microbatch=i + num_warmup_microbatches,
            encoder_decoder_xattn=encoder_decoder_xattn,
        )
        # In-place device add; no blocking .item() sync.
        total_num_tokens += num_tokens

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, config)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, config)

        else:
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, send_tensor_shapes, config
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            # Enable grad sync for the last microbatch in the batch if the full
            # backward pass completes in the 1F1B stage.
            if num_warmup_microbatches == 0 and last_iteration:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, config)
            else:
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, recv_tensor_shapes, config
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, config)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            send_backward(input_tensor_grad, recv_tensor_shapes, config)

        # Launch any remaining grad reductions.
        if no_sync_context is not None:
            enable_grad_sync()
            if config.grad_sync_func is not None:
                config.grad_sync_func(model.parameters())
    elif no_sync_context is not None:
        # Forward-only: there are no gradients to reduce, but the no-sync
        # context entered by disable_grad_sync() above must still be exited.
        # Previously it was left open and only cleaned up by garbage collection.
        enable_grad_sync()

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store