schedules.py 63.4 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

liangjing's avatar
v1  
liangjing committed
3
import contextlib
liangjing's avatar
liangjing committed
4
from typing import Iterator, List, Union
5

6
import torch
7
from torch.autograd.variable import Variable
8

9
10
from megatron.core import parallel_state
from megatron.core.enums import ModelType
liangjing's avatar
v1  
liangjing committed
11
from megatron.core.pipeline_parallel import p2p_communication
liangjing's avatar
liangjing committed
12
13
14
15
16
17
18
19
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    drain_embedding_wgrad_compute,
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
    get_model_xattn,
)
20

21
22
# Types
# A tensor shape as used by this module: either a plain list of ints or a torch.Size.
Shape = Union[List[int], torch.Size]
23

liangjing's avatar
v1  
liangjing committed
24

Jared Casper's avatar
Jared Casper committed
25
def get_forward_backward_func():
    """Pick the forward/backward schedule that matches the global parallel_state.

    Selection:
        * pipeline-model-parallel world size <= 1:
          ``forward_backward_no_pipelining``
        * world size > 1 and a virtual pipeline world size is set:
          ``forward_backward_pipelining_with_interleaving``
        * world size > 1 otherwise:
          ``forward_backward_pipelining_without_interleaving``

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The returned function takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and returns the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)


        forward_backward_func(forward_step_func=forward_step, ...)


    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func. Expected to be a list of
        iterators in the case of interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in the case of interleaved
        pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required):
        The number of microbatches to go through

    seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
        transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
        in the config is True. Otherwise, each microbatch in the current global batch size must use
        this sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
        transformer. This is ignored for a single-stack transformer.

    forward_only (optional, default = False): Perform only the forward step

    collect_non_loss_data (optional, bool, default=False): When True, the loss function
        returned by forward_step_func is called with ``non_loss_data=True`` on the last
        pipeline stage and its full result is collected instead of a reduced loss.

    first_val_step (bool, optional): Is the first step of the validation phase. Used by
        Transformer Engine modules to update their fp8 weights only on the first validation
        step.

    """
    # Without pipeline parallelism there is no inter-stage communication at all.
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    # A configured virtual pipeline size means each rank owns several model chunks.
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None:
        return forward_backward_pipelining_without_interleaving
    return forward_backward_pipelining_with_interleaving

liangjing's avatar
v1  
liangjing committed
113
114

def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    '''Replace the output tensor's '.data' with a 1-element placeholder.

    Call this right after the output tensor has been sent to the next
    pipeline stage: from then on only its '.grad_fn' is still needed, so the
    activation storage behind '.data' can be released.

    Args:
        out: the stage output tensor, or None (nothing to do).
        deallocate_pipeline_outputs: the feature switch; when False this is a no-op.

    Raises:
        AssertionError: if ``out`` is not a tensor, or is a view of another tensor.
    '''
    if out is not None and deallocate_pipeline_outputs:
        assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
        # Freeing a view would not release the base tensor's storage.
        assert out._base is None, "counter-productive to free a view of another tensor."
        out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
liangjing's avatar
v1  
liangjing committed
126

127

128
def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'deallocate_output_tensor' optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: single-element tensor to back-propagate from (an output whose
            '.data' was pseudo-freed, or a genuine scalar loss).
        grad_output: gradient with respect to ``output``, or None to use ones.

    Raises:
        AssertionError: if ``output`` is not a tensor or does not have exactly
            one element, or if ``grad_output`` is neither a tensor nor None.
    '''
    # Check the type first so a non-tensor fails with this informative message
    # rather than an AttributeError from ``.numel()``.
    assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
    assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(grad_output, (torch.Tensor, type(None))), (
        "grad_output == '%s'." % type(grad_output).__name__
    )

    # Handle scalar output: numel() == 1 was asserted above, so an implicit
    # all-ones gradient is well defined.
    if grad_output is None:
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
158
159


liangjing's avatar
liangjing committed
160
161
162
163
164
165
166
167
168
169
170
def set_current_microbatch(model, microbatch_id):
    """Store ``microbatch_id`` as ``current_microbatch`` on the model's decoder.

    The decoder is looked up through ``get_attr_wrapped_model``. This is a
    no-op when that lookup raises RuntimeError or yields None.
    """
    try:
        target = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        return
    if target is not None:
        target.current_microbatch = microbatch_id


liangjing's avatar
v1  
liangjing committed
171
172
173
174
175
176
177
178
179
180
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    """Run the user's forward step for one microbatch on ``model``.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used; it is installed through the
    model's ``set_input_tensor`` hook before the forward step runs.

    Args:
        forward_step_func (callable):
            The user forward step. It is called as
            ``forward_step_func(data_iterator, model)``, or as
            ``forward_step_func(data_iterator, model, checkpoint_activations_microbatch)``
            when ``checkpoint_activations_microbatch`` is not None. It must return a
            tuple of two elements:

                1. The output object from the forward step. This needs to be a tensor
                    or some kind of collection of tensors. The only hard requirement
                    is that it is acceptable as input to the function in (2).
                2. A function to (optionally) reduce the output from the forward step.
                    This could be a reduction over the loss from the model, a function
                    that reformats the model output, or a pass-through. On the last
                    pipeline stage it must follow one of these patterns:

                        a. Return ``(loss, data)``. The loss is divided by
                            ``num_microbatches`` so that the loss is stable as a
                            function of the number of devices the step is split across.
                        b. Return ``(loss, num_tokens, data)``. Like (a), but unless
                            ``config.calculate_per_token_loss`` is set the loss is first
                            divided by ``num_tokens`` as well. Useful when the user is
                            not already averaging across the number of tokens.
                        c. Return any arbitrary data (e.g. a dictionary or list of
                            tensors for inference). To trigger case (c), pass
                            ``collect_non_loss_data=True``; the function is then called
                            with ``non_loss_data=True``. You may also want to specify
                            ``forward_only=True`` in the call to the parent
                            forward_backward function.
        data_iterator (iterator):
            The data iterator, passed through to ``forward_step_func``.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches, used to average the loss.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for this stage. A bare tensor is wrapped in a list
            before being handed to ``set_input_tensor``.
        forward_data_store (list):
            Receives one entry per call on the last pipeline stage. For patterns
            2.a and 2.b this is only the final element (``data``) returned by the
            reduction function. For pattern 2.c it is the entire output of the
            reduction function applied to the model output.
        config (object):
            Model configuration. Reads ``timers``, ``enable_autocast``,
            ``autocast_dtype``, ``calculate_per_token_loss``, ``grad_scale_func``
            and, if present, ``num_moe_experts``.
        collect_non_loss_data (bool, optional):
            Select pattern 2.c, collecting arbitrary output from the model forward
            such as in inference use cases. Defaults to False.
        checkpoint_activations_microbatch (optional):
            Forwarded as a third argument to ``forward_step_func`` when not None.
            Defaults to None.
        is_first_microbatch (bool, optional):
            When True and the model has ``set_is_first_microbatch``, that method is
            called before the forward step. Defaults to False.
        current_microbatch (int, optional):
            When not None, recorded on the model's decoder via
            ``set_current_microbatch``. Defaults to None.
        encoder_decoder_xattn (bool, optional):
            When True, for an ``encoder_and_decoder`` model on a rank inside the
            decoder, the last input tensor is returned alongside the output so it
            can be sent downstream. Defaults to False.

    Returns:
        tuple: ``(output, num_tokens)``. ``output`` is the (possibly loss-reduced)
        output tensor. It is bare when ``input_tensor`` was passed bare, otherwise a
        one-element list, or ``[output, input_tensor[-1]]`` in the
        encoder/decoder cross-attention case above. ``num_tokens`` is the count
        returned under pattern 2.b, otherwise ``torch.tensor(0, dtype=torch.int)``.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    # The model hook always receives a list. Remember whether the caller passed
    # a bare tensor so the same form can be returned.
    was_bare_input = not isinstance(input_tensor, list)
    if was_bare_input:
        input_tensor = [input_tensor]
    get_attr_wrapped_model(model, "set_input_tensor")(input_tensor)

    autocast_ctx = (
        torch.autocast("cuda", dtype=config.autocast_dtype)
        if config.enable_autocast
        else contextlib.nullcontext()
    )
    # Only pass the checkpoint flag when one was given, so two-argument user
    # forward-step functions keep working.
    step_args = (data_iterator, model)
    if checkpoint_activations_microbatch is not None:
        step_args = step_args + (checkpoint_activations_microbatch,)
    with autocast_ctx:
        output_tensor, loss_func = forward_step_func(*step_args)

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if collect_non_loss_data:
            forward_data_store.append(loss_func(output_tensor, non_loss_data=True))
        else:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer. Since the
    # auxiliary-loss backward uses a trick, the scale must be set explicitly.
    if getattr(config, 'num_moe_experts', None) is not None:
        # Use grad_scale_func when available, otherwise a scale of 1.
        if config.grad_scale_func is not None:
            loss_scale = config.grad_scale_func(torch.ones(1, device=output_tensor.device))
        else:
            loss_scale = torch.tensor(1.0)
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    sends_encoder_state = (
        get_model_type(model) == ModelType.encoder_and_decoder
        and encoder_decoder_xattn
        and parallel_state.is_inside_decoder()
    )
    if sends_encoder_state:
        return [output_tensor, input_tensor[-1]], num_tokens
    return (output_tensor if was_bare_input else [output_tensor]), num_tokens
327
328


liangjing's avatar
v1  
liangjing committed
329
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        input_tensor: tensor (or list of tensors, entries may be None) that was
            fed into this stage. Non-None entries get ``retain_grad()`` so their
            ``.grad`` can be read after the backward pass.
        output_tensor: stage output tensor, or a list whose first element is it.
            When a list is passed, its first element may be replaced in place by
            the grad-scaled loss.
        output_tensor_grad: gradient for the output (None on the last stage),
            or a list of gradients. A second entry is the gradient of a single
            skip connection.
        model_type: model type. Only ``ModelType.encoder_and_decoder`` enables
            skip-connection handling.
        config: model config. Reads ``timers``, ``grad_scale_func`` and
            ``deallocate_pipeline_outputs``.

    Returns:
        Gradient of loss with respect to the input tensor(s), bare if
        ``input_tensor`` was passed bare, otherwise a list. Entries for None
        inputs are None, so the result is None on the first stage.
    """

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass. On the last stage (no incoming grad) apply loss scaling.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        # The output's '.data' may have been pseudo-freed, which
        # torch.autograd.backward's shape check would reject.
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor. input_tensor is always a list here.
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder). The cheap length check runs first.
    if (
        len(output_tensor_grad) > 1  # excludes models that lack a skip connection.
        and model_type == ModelType.encoder_and_decoder
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
    ):
        if output_tensor_grad[1] is not None:
            assert input_tensor_grad[-1] is not None
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad


liangjing's avatar
liangjing committed
397
398
399
400
401
402
403
def check_first_val_step(first_val_step, forward_only, cond):
    """Return ``cond``, gated by ``first_val_step`` during forward-only runs.

    When ``forward_only`` is set and ``first_val_step`` was given, the result is
    ``first_val_step and cond``. In every other case it is ``cond`` unchanged.
    """
    if forward_only and first_val_step is not None:
        return first_val_step and cond
    return cond


liangjing's avatar
v1  
liangjing committed
404
405
406
407
408
409
410
411
412
413
414
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: int = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    All microbatches except the last run inside ``config.no_sync_func()``. The
    last one runs outside it so gradient synchronization happens once. When not
    ``forward_only`` and ``config.finalize_model_grads_func`` is set, that
    function is called afterwards.

    Returns:
        list: ``forward_data_store``, holding whatever ``forward_step`` appended
        for each microbatch (the reduced-loss data, or the non-loss data when
        ``collect_non_loss_data`` is True).

    See get_forward_backward_func() for argument details
    """
    if isinstance(model, list):
        assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext

    model_type = get_model_type(model)

    forward_data_store = []
    # Without pipelining there is no received activation and no incoming grad.
    input_tensor, output_tensor_grad = None, None
    total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

    def run_microbatch(microbatch_id):
        """Run forward (and backward unless forward_only) for one microbatch.

        Returns the microbatch's token count as a Python number.
        """
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            is_first_microbatch=check_first_val_step(
                first_val_step, forward_only, microbatch_id == 0
            ),
            current_microbatch=microbatch_id,
        )
        microbatch_tokens = num_tokens.item()
        if not forward_only:
            backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
        return microbatch_tokens

    with no_sync_func():
        for i in range(num_microbatches - 1):
            total_num_tokens += run_microbatch(i)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    total_num_tokens += run_microbatch(num_microbatches - 1)

    if config.finalize_model_grads_func is not None and not forward_only:
        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism and layernorm all-reduce for sequence parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
498
499


liangjing's avatar
liangjing committed
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
def clear_embedding_activation_buffer(config, model):
    """Empty the deferred-embedding activation buffer on the last pipeline stage.

    Returns the module that owns the buffer, or None when this rank is not on
    the last pipeline stage (ignoring virtual stages) or
    ``config.defer_embedding_wgrad_compute`` is off. The owner is the
    ``post_process`` holder of the last model chunk when ``model`` is a list.
    """
    if not (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        return None

    owner = model[-1] if isinstance(model, list) else model
    embedding_module = get_attr_wrapped_model(owner, 'post_process', return_model_obj=True)

    # Need to ensure no stray activations exists in this buffer
    embedding_module.embedding_activation_buffer.clear()
    return embedding_module


def finish_embedding_wgrad_compute(config, embedding_module):
    """Drain the deferred output-layer weight-gradient computation.

    Only acts on the last pipeline stage (ignoring virtual stages) and when
    ``config.defer_embedding_wgrad_compute`` is set. It then hands the buffered
    activations and output gradients of ``embedding_module`` to
    ``drain_embedding_wgrad_compute`` together with the weight whose gradient
    they belong to.

    Args:
        config: model config (reads ``defer_embedding_wgrad_compute``).
        embedding_module: the module returned by
            ``clear_embedding_activation_buffer``.
    """
    if not (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        return

    embedding_activation_buffer = embedding_module.embedding_activation_buffer
    grad_output_buffer = embedding_module.grad_output_buffer
    # With tied embeddings the output projection uses the shared weight;
    # otherwise the output layer owns its own weight. The previous ternary had
    # these branches swapped.
    # NOTE(review): assumes the model's forward passes
    # shared_embedding_or_output_weight() to the output layer when weights are
    # tied (as Megatron's GPTModel does). Confirm for other post_process modules.
    weight = (
        embedding_module.shared_embedding_or_output_weight()
        if embedding_module.share_embeddings_and_output_weights
        else embedding_module.output_layer.weight
    )

    drain_embedding_wgrad_compute(
        config, embedding_activation_buffer, grad_output_buffer, weight
    )


liangjing's avatar
v1  
liangjing committed
539
540
541
542
543
544
545
546
547
548
549
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run the interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    All arguments are keyword-only. ``model`` and ``data_iterator`` must be
    lists with one entry per model chunk.

    The schedule has three phases:
    - warmup forward passes;
    - a steady phase that runs one forward and one backward pass per step;
    - cooldown backward passes that flush the pipeline.

    When ``forward_only`` is set, every microbatch runs in the warmup phase and
    no backward pass is executed.

    Side effects on ``config``:
    - a non-list ``grad_sync_func`` / ``param_sync_func`` is replaced by a list
      with one entry per chunk;
    - when ``forward_only`` is set, both are temporarily set to ``None`` and
      restored just before returning. They are not restored if an exception
      escapes.

    Raises:
        ValueError: if both ``config.overlap_p2p_comm`` and
            ``config.batch_p2p_comm`` are set.
        RuntimeError: if ``num_microbatches`` is not divisible by the pipeline
            parallel size, if the model is encoder-and-decoder, or if
            ``decoder_seq_length`` differs from ``seq_length``.

    Returns:
        The ``forward_data_store`` list that ``forward_step`` fills for each
        microbatch. What gets stored is decided by ``forward_step``.
    """
    assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
    assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"

    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if isinstance(no_sync_func, list):

        def multi_no_sync():
            # Enter every chunk's no-sync context on one ExitStack so they can
            # be exited together.
            stack = contextlib.ExitStack()
            for model_chunk_no_sync_func in config.no_sync_func:
                stack.enter_context(model_chunk_no_sync_func())
            return stack

        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
        config.grad_sync_func = [config.grad_sync_func for _ in model]

    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
        config.param_sync_func = [config.param_sync_func for _ in model]

    # Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
    # They will be re-enabled at the end of this function.
    grad_sync_func, param_sync_func = None, None
    if forward_only:
        grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
        config.grad_sync_func, config.param_sync_func = None, None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Model chunk IDs with synchronized grads
    synchronized_model_chunks = set()

    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()

    if num_microbatches % pipeline_parallel_size != 0:
        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
        msg += 'when using interleaved schedule'
        raise RuntimeError(msg)

    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
        )

    # The sequence dimension is split across context-parallel ranks and, with
    # sequence parallelism, across tensor-parallel ranks as well.
    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    total_num_microbatches = num_microbatches * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = total_num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if num_microbatches == pipeline_parallel_size:
            num_warmup_microbatches = total_num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer to appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
        config.param_sync_func[0](model[0].parameters())
        config.param_sync_func[1](model[1].parameters())

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            # Backward passes visit the chunks in reverse order.
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def get_microbatch_id_in_model_chunk(iteration_id, forward):
        """Helper method to get the microbatch_id within model chunk given the iteration number."""
        assert forward
        iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
        microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
            iteration_id % pipeline_parallel_size
        )
        return microbatch_id_in_model_chunk

    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
            return microbatch_id_in_group % pipeline_parallel_size == 0
        else:
            return False

    def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the last for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == num_microbatch_groups - 1:
            return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
        else:
            return False

    def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        nonlocal total_num_tokens

        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch param synchronization for next model chunk
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.param_sync_func is not None:
            param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
            if (
                param_sync_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                # Chunks 0 and 1 were already synchronized before the schedule started.
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func[param_sync_chunk_id](
                        model[param_sync_chunk_id].parameters()
                    )

        # forward step
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id)
            ),
            current_microbatch=current_microbatch,
        )
        output_tensors[model_chunk_id].append(output_tensor)

        total_num_tokens += num_tokens.item()

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch grad synchronization (default)
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )

        # launch grad synchronization (custom grad sync)
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.grad_sync_func is not None:
            grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
            if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                grad_sync_microbatch_id
            ):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()

        return input_tensor_grad

    # Run warmup forward passes.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
    bwd_wait_handles = None

    for k in range(num_warmup_microbatches):

        if fwd_wait_handles is not None:
            for req in fwd_wait_handles:
                req.wait()

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
        output_tensor = forward_step_helper(
            k, current_microbatch, checkpoint_activations_microbatch
        )

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (total_num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if not config.overlap_p2p_comm:
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (input_tensor, output_tensor_grad) = (
                    p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor,
                        input_tensor_grad,
                        recv_prev=recv_prev,
                        recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        config=config,
                    )
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        else:
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                (output_tensor_grad, bwd_wait_handles) = (
                    p2p_communication.send_backward_recv_backward(
                        input_tensor_grad,
                        recv_next=recv_next,
                        tensor_shape=tensor_shape,
                        config=config,
                        overlap_p2p_comm=True,
                    )
                )

                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles:
                    req.wait()

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

            # Last virtual stage no activation tensor to send
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # Send activation tensor to the next stage and receive activation tensor from the
            # previous stage
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            # assert fwd_wait_handles is not None

            if bwd_wait_handles is not None:
                for req in bwd_wait_handles:
                    req.wait()

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)

            # First virtual stage no activation gradient tensor to send
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if the current virtual stage has an activation gradient tensor to receive
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
                input_tensor_grad,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

        else:  # no p2p overlap
            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # Communicate tensors.
            (input_tensor, output_tensor_grad) = (
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if config.overlap_p2p_comm and bwd_wait_handles is not None:
            for wait_handle in bwd_wait_handles:
                wait_handle.wait()

        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
        for k in range(num_microbatches_remaining, total_num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (total_num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )
            )

        # Launch any remaining grad reductions.
        enable_grad_sync()
        if config.grad_sync_func is not None:
            for model_chunk_id in range(num_model_chunks):
                if model_chunk_id not in synchronized_model_chunks:
                    config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                    synchronized_model_chunks.add(model_chunk_id)

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    # Restore config.grad_sync_func and config.param_sync_func.
    if forward_only:
        config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store


def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
    encoder_decoder_xattn: bool,
):
    """Compute the shapes of the activation tensors a pipeline stage exchanges.

    Sequence lengths are first divided by the context-parallel world size.
    When ``config.sequence_parallel`` is set, they are then also divided by the
    tensor-model-parallel world size. The decoder sequence length is only
    rescaled (and only used) for encoder-and-decoder models.

    Args:
        rank: pipeline rank whose tensor shapes are requested.
        model_type: ``ModelType`` of the model being scheduled.
        seq_length: full (encoder) sequence length.
        micro_batch_size: micro-batch size.
        decoder_seq_length: full decoder sequence length (encoder-and-decoder only).
        config: model config; ``sequence_parallel`` and ``hidden_size`` are read.
        encoder_decoder_xattn: whether decoder stages also carry the encoder
            output (cross-attention).

    Returns:
        A list of ``(seq_len, micro_batch_size, hidden_size)`` tuples:

        * not encoder-and-decoder: one shape using ``seq_length``;
        * encoder-and-decoder with ``rank`` inside the encoder: one shape using
          ``seq_length``;
        * encoder-and-decoder decoder stage with cross-attention: the decoder
          shape first, then the encoder shape;
        * encoder-and-decoder decoder stage without cross-attention: one
          decoder shape.
    """
    is_enc_dec = model_type == ModelType.encoder_and_decoder

    # Split sequences across context-parallel ranks.
    cp_world_size = parallel_state.get_context_parallel_world_size()
    enc_len = seq_length // cp_world_size
    dec_len = decoder_seq_length // cp_world_size if is_enc_dec else decoder_seq_length

    # Sequence parallelism further splits the sequence across TP ranks.
    if config.sequence_parallel:
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        enc_len //= tp_world_size
        if is_enc_dec:
            dec_len //= tp_world_size

    hidden = config.hidden_size
    seq_shape = (enc_len, micro_batch_size, hidden)

    if not is_enc_dec or parallel_state.is_inside_encoder(rank):
        return [seq_shape]

    dec_shape = (dec_len, micro_batch_size, hidden)
    if encoder_decoder_xattn:
        # Decoder activations go first, followed by the encoder output.
        return [dec_shape, seq_shape]
    return [dec_shape]


liangjing's avatar
v1  
liangjing committed
1206
def recv_forward(tensor_shapes, config):
    """Receive one activation tensor from the previous pipeline stage per shape.

    Args:
        tensor_shapes: shapes to receive, in order. A ``None`` entry means
            nothing is received for that slot.
        config: model config forwarded to ``p2p_communication.recv_forward``.

    Returns:
        A list aligned with ``tensor_shapes`` that holds the received tensor,
        or ``None`` where the shape was ``None``.
    """
    return [
        None if shape is None else p2p_communication.recv_forward(shape, config)
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
1216
def recv_backward(tensor_shapes, config):
    """Receive one output-gradient tensor from the next pipeline stage per shape.

    Args:
        tensor_shapes: shapes of the gradients to receive, in order. A ``None``
            entry means nothing is received for that slot.
        config: model config forwarded to ``p2p_communication.recv_backward``.

    Returns:
        A list aligned with ``tensor_shapes`` that holds the received gradient,
        or ``None`` where the shape was ``None``.
    """
    return [
        None if shape is None else p2p_communication.recv_backward(shape, config)
        for shape in tensor_shapes
    ]


liangjing's avatar
v1  
liangjing committed
1226
def send_forward(output_tensors, tensor_shapes, config):
    """Send output activations to the next pipeline stage.

    ``output_tensors`` may be a single object or a list. A non-list value is
    treated as a one-element list. Tensors are paired with ``tensor_shapes``
    positionally; ``zip`` stops at the shorter of the two. Slots whose shape is
    ``None`` are skipped. The shape itself is not passed to the send; only its
    presence is checked.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, config)
1233
1234


liangjing's avatar
v1  
liangjing committed
1235
def send_backward(input_tensor_grads, tensor_shapes, config):
    """Send input gradients back to the previous pipeline stage.

    ``input_tensor_grads`` may be a single object or a list. A non-list value
    is treated as a one-element list. Gradients are paired with
    ``tensor_shapes`` positionally; ``zip`` stops at the shorter of the two.
    Slots whose shape is ``None`` are skipped. The shape itself is not passed
    to the send; only its presence is checked.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, config)
1242
1243


liangjing's avatar
v1  
liangjing committed
1244
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
    """Send activations forward and receive the matching output gradients.

    Args:
        output_tensors: a single object or a list. A non-list value is treated
            as a one-element list.
        tensor_shapes: shapes paired positionally with ``output_tensors``
            (``zip`` stops at the shorter). A ``None`` shape skips that slot.
        config: model config forwarded to the p2p call.

    Returns:
        One entry per paired slot: the gradient returned by
        ``p2p_communication.send_forward_recv_backward``, or ``None`` for
        skipped slots.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    return [
        None
        if shape is None
        else p2p_communication.send_forward_recv_backward(tensor, shape, config)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1259
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
    """Send input gradients backward and receive the next input activations.

    Args:
        input_tensor_grads: a single object or a list. A non-list value is
            treated as a one-element list.
        tensor_shapes: shapes paired positionally with ``input_tensor_grads``
            (``zip`` stops at the shorter). A ``None`` shape skips that slot.
        config: model config forwarded to the p2p call.

    Returns:
        One entry per paired slot: the tensor returned by
        ``p2p_communication.send_backward_recv_forward``, or ``None`` for
        skipped slots.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    return [
        None
        if shape is None
        else p2p_communication.send_backward_recv_forward(grad, shape, config)
        for grad, shape in zip(grads, tensor_shapes)
    ]


liangjing's avatar
v1  
liangjing committed
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run the non-interleaved 1F1B schedule, communicating between pipeline stages.

    Each pipeline rank runs three phases:

    * warmup: ``pipeline_world_size - pipeline_rank - 1`` forward passes,
      capped at ``num_microbatches``;
    * steady state: for each remaining microbatch, one forward pass followed
      (unless ``forward_only``) by one backward pass;
    * cooldown (skipped when ``forward_only``): the backward passes of the
      warmup microbatches.

    Args:
        forward_step_func: user forward-step function passed to ``forward_step``.
        data_iterator: data iterator, or a one-element list holding one.
        model: model, or a one-element list holding one. Model chunking is not
            supported by this schedule.
        num_microbatches: number of microbatches to run.
        seq_length: (encoder) sequence length, used to derive p2p tensor shapes.
        micro_batch_size: micro-batch size, used to derive p2p tensor shapes.
        decoder_seq_length: decoder sequence length for encoder-and-decoder models.
        forward_only: if True, run forward passes only; no backward passes
            and no gradient finalization.
        collect_non_loss_data: forwarded to ``forward_step``.
        first_val_step: forwarded to ``forward_step`` via ``check_first_val_step``.

    Returns:
        The ``forward_data_store`` list that ``forward_step`` appends to. Per
        the original docstring, it holds losses on the last stage and is empty
        on the other stages.

    Raises:
        AssertionError: if ``model`` or ``data_iterator`` is a list whose
            length is not 1.
        ValueError: if ``config.overlap_p2p_comm`` is set.
    """

    if isinstance(model, list):
        assert (
            len(model) == 1
        ), "non-interleaved pipeline parallelism does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.overlap_p2p_comm:
        raise ValueError(
            "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
        )

    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions. The no-sync context (or a no-op
    # nullcontext) is entered now and exited by enable_grad_sync().
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)
    encoder_decoder_xattn = get_model_xattn(model)

    rank = parallel_state.get_pipeline_model_parallel_rank()
    # Activations received from the previous stage (and gradients sent back
    # to it) have the shapes that stage sends, hence ``rank - 1``.
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=encoder_decoder_xattn,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=encoder_decoder_xattn,
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = recv_forward(recv_tensor_shapes, config)
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
            encoder_decoder_xattn=encoder_decoder_xattn,
        )
        send_forward(output_tensor, send_tensor_shapes, config)
        total_num_tokens += num_tokens.item()

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, config)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
            ),
            current_microbatch=i + num_warmup_microbatches,
            encoder_decoder_xattn=encoder_decoder_xattn,
        )
        total_num_tokens += num_tokens.item()

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, config)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, config)

        else:
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, send_tensor_shapes, config
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            # Enable grad sync for the last microbatch in the batch if the full
            # backward pass completes in the 1F1B stage.
            if num_warmup_microbatches == 0 and last_iteration:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, config)
            else:
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, recv_tensor_shapes, config
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, config)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            send_backward(input_tensor_grad, recv_tensor_shapes, config)

        # Launch any remaining grad reductions.
        if no_sync_context is not None:
            enable_grad_sync()
            if config.grad_sync_func is not None:
                config.grad_sync_func(model.parameters())
    else:
        # No backward pass ran, but disable_grad_sync() above still entered the
        # no-sync context. Exit it explicitly so that any state the context
        # manager changed is restored here, deterministically, instead of
        # whenever the context object happens to be garbage collected.
        enable_grad_sync()

    if config.finalize_model_grads_func is not None and not forward_only:

        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store