schedules.py 35.3 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

3
4
5
from contextlib import contextmanager, nullcontext
from typing import Optional, List, Union, Callable, Any

6
import torch
7
from torch.autograd.variable import Variable
8
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
9

10
11
12
13
from megatron.core import parallel_state
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.enums import ModelType
from megatron.core.utils import get_attr_wrapped_model, get_model_type
14

15
16
# Types
Shape = Union[List[int], torch.Size]
17

Jared Casper's avatar
Jared Casper committed
18
def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.

    Returns a function that will perform all of the forward and
    backward passes of the model given the pipeline model parallel
    world size and virtual pipeline model parallel world size in the
    global parallel_state:

    * pipeline world size <= 1: forward_backward_no_pipelining
    * pipeline world size > 1, no virtual pipeline size:
      forward_backward_pipelining_without_interleaving
    * pipeline world size > 1 with a virtual pipeline size:
      forward_backward_pipelining_with_interleaving

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
        iterator and a model as its arguments and returns the model's
        forward output and the loss function. The loss function should
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        For example:

        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}

        def forward_step(data_iterator, model):
            data, loss_mask = next(data_iterator)
            output = model(data)
            return output, partial(loss_func, loss_mask)


        forward_backward_func(forward_step_func=forward_step, ...)


    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func

    model (required): the actual model. A torch.nn.Module or, in the
        case of interleaving, a list of torch.nn.Module

    num_microbatches (int, required):
        The number of microbatches to go through

    dtype (required when using pipeline parallelism): dtype used in
        p2p communication, usually params_dtype

    tensor_shape (required when using pipeline parallelism): Shape of
        tensor. The tensor is expected to be 3D and its order of
        dimension is supposed to be ``(sequence, batch, hidden)``.

    decoder_seq_length (int, required for ModelType.encoder_and_decoder models):
        Sequence length of the decoder portion, used to determine tensor shapes.

    grad_scaler (optional, default=None): If using loss scaling,
        this function should take the loss and return the scaled
        loss. If None, no function is called on the loss.

    sequence_parallel (optional, default=False):
        Set to :obj:`True` for this function to handle sequence
        length.  When :obj:`True`, the sequence length on each tensor
        model parallel rank is updated to
        ``original_sequence_length // tensor_model_parallel_world_size``.
        TODO: Do we need this? Just roll into tensor_shape arg?

    forward_only (optional, default=False): Perform only the forward step

    timers (optional, default=None): callable taking a timer name (and
        optionally ``log_level``) and returning an object with
        ``start()``/``stop()``; used to time 'forward-compute' and
        'backward-compute' and passed through to p2p communication.

    collect_non_loss_data (optional, default=False): when True, on the last
        pipeline stage the loss function is called with
        ``non_loss_data=True`` and its result is collected instead of the
        loss.

    """
    # Note: the docstring avoids ``\_`` sequences, which are invalid escapes in
    # a non-raw string literal (SyntaxWarning on Python 3.12+).
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None:
        return forward_backward_pipelining_without_interleaving
    return forward_backward_pipelining_with_interleaving

104
105
def deallocate_output_tensor(out):
    '''Pseudo-free a tensor by swapping its '.data' for a 1-element placeholder.

    Meant to run right after the output tensor has been sent to the next
    pipeline stage. From that point the tensor is only needed for its
    autograd history ('.grad_fn'), not for its values. The placeholder keeps
    the tensor's device and dtype.

    ``None`` is accepted and ignored. Raises AssertionError if ``out`` is not
    a tensor, or if it is a view of another tensor.
    '''
    if out is None:
        return

    assert isinstance(out, torch.Tensor), \
        "expected Tensor, found %s." % type(out).__name__
    # A view shares storage with its base, so replacing the view's data
    # would not release that storage.
    assert out._base is None, \
        "counter-productive to free a view of another tensor."

    placeholder = torch.empty((1,), dtype=out.dtype, device=out.device)
    out.data = placeholder
122

123
def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    To make the 'deallocate_output_tensor' optimization work, the C++
    autograd engine must be called directly, bypassing Pytorch's
    torch.autograd.backward. Pytorch's 'backward' checks that the output and
    grad have the same shape, while C++'s 'backward' does not.

    Args:
        output: 1-element tensor to backpropagate from (a scalar loss, or an
            output whose data was pseudo-freed by deallocate_output_tensor).
        grad_output: gradient with respect to ``output``, or None to use a
            tensor of ones shaped like ``output``.

    Raises:
        AssertionError: if either argument has the wrong type, or if
            ``output`` has more than one element.
    '''
    # Check types before touching tensor methods, so a non-tensor argument
    # fails with the intended message instead of an AttributeError.
    assert isinstance(output, torch.Tensor), \
        "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), \
        "grad_output == '%s'." % type(grad_output).__name__
    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"

    # Implicit gradient for a scalar output (numel == 1 is asserted above).
    if grad_output is None:
        grad_output = torch.ones_like(
            output,
            memory_format=torch.preserve_format,
        )

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
157
158
159
160




Jared Casper's avatar
Jared Casper committed
161

162
163
164
def forward_step(forward_step_func,
                 data_iterator,
                 model,
                 num_microbatches,
                 input_tensor,
                 forward_data_store,
                 timers,
                 collect_non_loss_data=False):
    """Run the forward pass of one microbatch through ``model``.

    ``input_tensor`` (a tensor, None, or a list of them) is handed to the
    model's ``set_input_tensor`` hook before ``forward_step_func`` is called
    with ``data_iterator`` and ``model``.

    On the last pipeline stage, the loss function returned by
    ``forward_step_func`` is applied. Normally the resulting loss is divided
    by ``num_microbatches`` and the reduced-loss value is appended to
    ``forward_data_store``. With ``collect_non_loss_data``, the loss function
    is called with ``non_loss_data=True`` and its result is appended instead.

    Returns the output tensor:
    * as a bare value when ``input_tensor`` was not a list, and as a
      one-element list otherwise;
    * as ``[output, input_tensor[-1]]`` on stages after the split of an
      encoder-and-decoder model.
    """
    if timers is not None:
        timers('forward-compute', log_level=2).start()

    input_was_list = isinstance(input_tensor, list)
    inputs = input_tensor if input_was_list else [input_tensor]

    get_attr_wrapped_model(model, "set_input_tensor")(inputs)

    # Re-enter autocast for the user step when autocast is already active.
    if torch.is_autocast_enabled():
        autocast_ctx = torch.autocast("cuda")
    else:
        autocast_ctx = nullcontext()
    with autocast_ctx:
        output_tensor, loss_func = forward_step_func(data_iterator, model)

    if parallel_state.is_pipeline_last_stage():
        if collect_non_loss_data:
            forward_data_store.append(loss_func(output_tensor, non_loss_data=True))
        else:
            loss, loss_reduced = loss_func(output_tensor)
            output_tensor = loss / num_microbatches
            forward_data_store.append(loss_reduced)

    if timers is not None:
        timers('forward-compute').stop()

    # For a model with encoder and decoder (e.g. T5), decoder-side stages
    # also pass the encoder hidden state downstream.
    model_type = get_model_type(model)
    if parallel_state.is_pipeline_stage_after_split() and \
            model_type == ModelType.encoder_and_decoder:
        return [output_tensor, inputs[-1]]
    return [output_tensor] if input_was_list else output_tensor
215
216


217
218
def backward_step(grad_scaler, input_tensor, output_tensor,
                  output_tensor_grad, model_type, timers):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        grad_scaler: optional callable applied to the loss before backward
            when no incoming gradient is given (last stage).
        input_tensor: tensor, None, or list of them that was fed to this
            stage's forward pass.
        output_tensor: this stage's output tensor, or a list whose first
            element is the tensor to backpropagate from.
        output_tensor_grad: gradient w.r.t. the output (or a list of
            gradients); None on the last stage.
        model_type: model type; used to add the encoder-and-decoder skip
            connection gradient.
        timers: optional timer factory; times 'backward-compute'.

    Returns gradient of loss with respect to input tensor (None if first
    stage). It is a bare value when ``input_tensor`` was not a list, and
    otherwise a list with one entry per input tensor."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if timers is not None:
        timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass. Scale into a separate local rather than rebinding
    # ``output_tensor``. Otherwise a following ``output_tensor[0]`` would
    # index into the scaled loss tensor, which fails for a 0-dim loss.
    backward_from = output_tensor[0]
    if output_tensor_grad[0] is None and grad_scaler is not None:
        backward_from = grad_scaler(backward_from)
    custom_backward(backward_from, output_tensor_grad[0])

    # Collect the grad of the input_tensor (always a list at this point).
    input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
            parallel_state.is_pipeline_stage_after_split() and \
            model_type == ModelType.encoder_and_decoder:
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if timers is not None:
        timers('backward-compute').stop()

    return input_tensor_grad


279
280
281
282
283
284
285
286
@contextmanager
def dummy_handler():
    """No-op context manager.

    Used in place of ``DistributedDataParallel.no_sync`` when the model is
    not wrapped in torch DDP. It yields None and lets exceptions propagate.
    """
    yield


287
288
289
290
291
292
293
294
295
296
297
298
299
def forward_backward_no_pipelining(*,
                                   forward_step_func,
                                   data_iterator,
                                   model: Union[torch.nn.Module, List[torch.nn.Module]],
                                   num_microbatches: int,
                                   dtype: Optional[torch.dtype] = None, # unused
                                   tensor_shape: Optional[Shape] = None, # unused
                                   decoder_seq_length: Optional[int] = None, # unused
                                   grad_scaler: Callable = None,
                                   sequence_parallel: bool = False, # unused
                                   forward_only: bool = False,
                                   timers: Callable = None,
                                   collect_non_loss_data: bool = False):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    ``model`` and ``data_iterator`` may be passed either directly or as
    single-element lists (the chunked calling convention used by the
    pipelined schedules). More than one chunk is rejected.

    When the model is wrapped in torch DDP, all but the last microbatch run
    inside ``model.no_sync()``. The last microbatch runs outside it so that
    gradients are synchronized once.

    Returns the ``forward_data_store`` list that forward_step appends to:
    reduced-loss values, or non-loss data when ``collect_non_loss_data`` is
    set.

    See get_forward_backward_func() for argument details
    """
    # Accept a bare module as well as a one-element list. Calling len()/[0]
    # on a bare module would fail, or, for containers such as
    # nn.Sequential, silently select the first submodule.
    if isinstance(model, list):
        assert len(model) == 1, \
            "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert len(data_iterator) == 1, \
            "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    model_type = get_model_type(model)

    forward_data_store = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for _ in range(num_microbatches - 1):
            output_tensor = forward_step(forward_step_func, data_iterator,
                                         model, num_microbatches, input_tensor, forward_data_store,
                                         timers, collect_non_loss_data)
            if not forward_only:
                backward_step(grad_scaler, input_tensor, output_tensor,
                              output_tensor_grad, model_type, timers)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator,
                                 model, num_microbatches, input_tensor, forward_data_store,
                                 timers, collect_non_loss_data)

    if not forward_only:
        backward_step(grad_scaler, input_tensor, output_tensor,
                      output_tensor_grad, model_type, timers)

    return forward_data_store
339
340


341
342
343
344
345
346
347
348
349
350
351
352
353
def forward_backward_pipelining_with_interleaving(*,
                                                  forward_step_func,
                                                  data_iterator,
                                                  model: Union[torch.nn.Module, List[torch.nn.Module]],
                                                  num_microbatches: int,
                                                  dtype: torch.dtype,
                                                  tensor_shape: Shape,
                                                  decoder_seq_length: Optional[int] = None,
                                                  grad_scaler: Callable = None,
                                                  sequence_parallel: bool = False,
                                                  forward_only: bool = False,
                                                  timers: Callable = None,
                                                  collect_non_loss_data: bool = False):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    ``model`` and ``data_iterator`` are indexed per model chunk
    (``model[chunk_id]``, ``data_iterator[chunk_id]``). Each chunk is
    selected as the virtual pipeline rank while it runs.

    Raises RuntimeError if ``num_microbatches`` is not divisible by the
    pipeline-parallel size, if the model is an encoder-and-decoder model, or
    if ``decoder_seq_length`` is given and differs from ``tensor_shape[0]``.

    Returns ``forward_data_store``, the list that forward_step appends to.
    forward_step appends only on the last (virtual) pipeline stage, so the
    list is empty on other ranks.

    See get_forward_backward_func() for argument details."""

    # Per-chunk FIFO queues of tensors saved between forward and backward.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    forward_data_store = []
    # Only allocated when backward passes will run; every use of
    # output_tensor_grads below is reached only when not forward_only.
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()

    if num_microbatches % pipeline_parallel_size != 0:
        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
        msg += 'when using interleaved schedule'
        raise RuntimeError(msg)

    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

    if decoder_seq_length is not None and decoder_seq_length != tensor_shape[0]:
        raise RuntimeError("Interleaving is not supported with a different decoder sequence length.")

    # With sequence parallelism each tensor-parallel rank communicates only
    # its slice of the sequence dimension.
    if sequence_parallel:
        seq_length, batch_size, hidden = tensor_shape
        tensor_shape = (
            seq_length // parallel_state.get_tensor_model_parallel_world_size(),
            batch_size,
            hidden,
        )

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    total_num_microbatches = num_microbatches * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = total_num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if num_microbatches == pipeline_parallel_size:
            num_warmup_microbatches = total_num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          total_num_microbatches)
    num_microbatches_remaining = \
        total_num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Consecutive microbatch ids are grouped in runs of
        pipeline_parallel_size per chunk. Backward passes visit the chunks in
        reverse order."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step()).

        Consumes the most recently queued input tensor for the chunk and
        returns the chunk's output tensor."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # The first (virtual) stage receives nothing from upstream, so it
        # queues a None input whenever inputs and outputs are balanced.
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     num_microbatches,
                                     input_tensor,
                                     forward_data_store,
                                     timers,
                                     collect_non_loss_data)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step()).

        Pops the oldest saved input/output/output-grad for the chunk and
        returns the gradient with respect to the chunk's input."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # The last (virtual) stage has no downstream gradient; a None grad
        # makes backward_step start from the loss.
        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(grad_scaler,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad,
                          model_type,
                          timers)

        return input_tensor_grad

    # Run warmup forward passes.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, dtype, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (total_num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        # (The virtual rank here is still the one set by forward_step_helper.)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        # On the final warmup iteration (when a steady state follows), also
        # receive the first output gradient for the last model chunk.
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        tensor_shape=tensor_shape, dtype=dtype,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev,
                    tensor_shape=tensor_shape, dtype=dtype,
                    timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        # The sent output's values are no longer needed, only its graph.
        deallocate_output_tensor(output_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    tensor_shape=tensor_shape, dtype=dtype, timers=timers)
        deallocate_output_tensor(output_tensor)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        # With no steady state, the first output gradient for the last chunk
        # has not been received yet.
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, dtype=dtype, timers=timers))
        for k in range(num_microbatches_remaining, total_num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (total_num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape, dtype=dtype,
                    timers=timers))

    return forward_data_store
613

614
615
616
617
618
619
def get_tensor_shapes(*,
                      rank: int,
                      model_type: ModelType,
                      tensor_shape: Shape,
                      decoder_seq_length: int,
                      sequence_parallel: bool):
    """Return the list of shapes this rank sends/receives per microbatch.

    The shapes depend on where ``rank`` sits relative to the encoder/decoder
    split and on the model type:

    * encoder-and-decoder model (e.g. T5), rank before the split: one
      ``(seq_length, micro_batch_size, hidden_size)`` shape;
    * encoder-and-decoder model, rank after the split: the decoder shape
      ``(decoder_seq_length, ...)`` followed by the encoder shape;
    * any other model: one ``(seq_length, micro_batch_size, hidden_size)``.

    With ``sequence_parallel`` the sequence lengths are divided (floor
    division) by the tensor-model-parallel world size.
    """
    assert (
        len(tensor_shape) == 3
    ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"

    seq_length, micro_batch_size, hidden_size = tensor_shape
    if sequence_parallel:
        seq_length //= parallel_state.get_tensor_model_parallel_world_size()
    seq_shape = (seq_length, micro_batch_size, hidden_size)

    if model_type != ModelType.encoder_and_decoder:
        return [seq_shape]

    if sequence_parallel:
        decoder_seq_length = decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()

    if parallel_state.is_pipeline_stage_before_split(rank):
        return [seq_shape]
    return [(decoder_seq_length, micro_batch_size, hidden_size), seq_shape]


653
654

def recv_forward(tensor_shapes, dtype, timers):
    """Receive one forward activation per entry of ``tensor_shapes``.

    A ``None`` shape yields ``None`` in the result without any communication.
    """
    return [
        None if shape is None
        else p2p_communication.recv_forward(shape, dtype, timers=timers)
        for shape in tensor_shapes
    ]


665
def recv_backward(tensor_shapes, dtype, timers):
    """Receive one output gradient per entry of ``tensor_shapes``.

    A ``None`` shape yields ``None`` in the result without any communication.
    """
    grads = []
    for shape in tensor_shapes:
        grads.append(
            None if shape is None
            else p2p_communication.recv_backward(shape, dtype, timers=timers)
        )
    return grads


def send_forward(output_tensors, tensor_shapes, timers):
    """Send each output tensor whose paired shape is not None.

    ``output_tensors`` may be a single tensor or a list; tensors are paired
    with ``tensor_shapes`` positionally (pairing stops at the shorter of the
    two) and pairs whose shape is None are skipped.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, timers=timers)


def send_backward(input_tensor_grads, tensor_shapes, timers):
    """Send each input gradient whose paired shape is not None.

    ``input_tensor_grads`` may be a single tensor or a list; gradients are
    paired with ``tensor_shapes`` positionally and pairs whose shape is None
    are skipped.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, timers=timers)


def send_forward_recv_backward(output_tensors, tensor_shapes, dtype, timers):
    """Send outputs and receive their gradients in one combined exchange.

    ``output_tensors`` may be a single tensor or a list. For each
    (tensor, shape) pair with a non-None shape, the tensor, shape and
    ``dtype`` are handed to ``p2p_communication.send_forward_recv_backward``
    and its result is collected; None shapes produce a None entry.
    """
    if isinstance(output_tensors, list):
        tensors = output_tensors
    else:
        tensors = [output_tensors]
    return [
        None if shape is None
        else p2p_communication.send_forward_recv_backward(
            tensor, shape, dtype, timers=timers)
        for tensor, shape in zip(tensors, tensor_shapes)
    ]


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers):
    """Send input gradients and receive the next inputs in one exchange.

    ``input_tensor_grads`` may be a single tensor or a list. For each
    (grad, shape) pair with a non-None shape, the gradient, shape and
    ``dtype`` are handed to ``p2p_communication.send_backward_recv_forward``
    and its result is collected; None shapes produce a None entry.
    """
    if isinstance(input_tensor_grads, list):
        grads = input_tensor_grads
    else:
        grads = [input_tensor_grads]
    return [
        None if shape is None
        else p2p_communication.send_backward_recv_forward(
            grad, shape, dtype, timers=timers)
        for grad, shape in zip(grads, tensor_shapes)
    ]


def forward_backward_pipelining_without_interleaving(*,
                                                     forward_step_func,
                                                     data_iterator,
                                                     model: Union[torch.nn.Module, List[torch.nn.Module]],
                                                     num_microbatches: int,
                                                     dtype: torch.dtype,
                                                     tensor_shape: Shape,
                                                     decoder_seq_length: Optional[int] = None,
                                                     grad_scaler: Optional[Callable] = None,
                                                     sequence_parallel: bool = False,
                                                     forward_only: bool = False,
                                                     timers: Optional[Callable] = None,
                                                     collect_non_loss_data: bool = False):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    On this pipeline stage the schedule runs in three phases:

    * warmup: ``pipeline_world_size - pipeline_rank - 1`` forward passes
      (capped at ``num_microbatches``);
    * steady state: one forward followed by one backward pass for each of the
      remaining microbatches;
    * cooldown: the backward passes still owed for the warmup microbatches.

    With ``forward_only=True`` no backward pass runs and no activations are
    kept.

    Args:
        forward_step_func: user forward function, passed to ``forward_step``.
        data_iterator: passed to ``forward_step``.
        model: a single module, or a one-element sequence of modules (model
            chunking is not supported by this schedule).
        num_microbatches: number of microbatches to run.
        dtype: dtype used when receiving tensors from neighbouring stages.
        tensor_shape: ``[sequence_length, micro_batch_size, hidden_size]`` of
            the tensors exchanged between stages.
        decoder_seq_length: decoder sequence length, used when the model type
            is ``ModelType.encoder_and_decoder``.
        grad_scaler: passed to ``backward_step``.
        sequence_parallel: if True, the exchanged sequence length is divided by
            the tensor-model-parallel world size.
        forward_only: skip all backward passes.
        timers: passed to the communication and step helpers.
        collect_non_loss_data: passed to ``forward_step``.

    Returns:
        ``forward_data_store``: the list handed to ``forward_step`` as its
        data store (its contents depend on ``forward_step`` and
        ``collect_non_loss_data``).
    """
    # The type hint allows a bare module, but only sized containers (list,
    # ModuleList, ...) support len()/[0]; wrap unsized modules so both work.
    if isinstance(model, torch.nn.Module) and not hasattr(type(model), '__len__'):
        model = [model]
    assert len(model) == 1, \
        "non-interleaved pipeline parallelism does not support model chunking"
    model = model[0]

    model_type = get_model_type(model)
    rank = parallel_state.get_pipeline_model_parallel_rank()

    # Later stages run fewer warmup forwards; the last stage runs none.
    num_warmup_microbatches = min(
        parallel_state.get_pipeline_model_parallel_world_size() - rank - 1,
        num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # What this stage receives is shaped like stage (rank - 1)'s output; what
    # it sends is shaped like its own output.
    recv_tensor_shapes = get_tensor_shapes(rank=rank - 1,
                                           model_type=model_type,
                                           tensor_shape=tensor_shape,
                                           decoder_seq_length=decoder_seq_length,
                                           sequence_parallel=sequence_parallel)
    send_tensor_shapes = get_tensor_shapes(rank=rank,
                                           model_type=model_type,
                                           tensor_shape=tensor_shape,
                                           decoder_seq_length=decoder_seq_length,
                                           sequence_parallel=sequence_parallel)

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            # Called once the output has been sent; presumably releases its
            # data while the saved entry keeps the autograd graph (TODO confirm).
            deallocate_output_tensor(output_tensor[0])

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
                                     input_tensor, forward_data_store,
                                     timers, collect_non_loss_data)

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)

        else:
            output_tensor_grad = \
                send_forward_recv_backward(output_tensor,
                                           send_tensor_shapes, dtype,
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0])

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass (FIFO: oldest microbatch is backpropagated first).
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(grad_scaler, input_tensor, output_tensor,
                              output_tensor_grad, model_type, timers)

            if last_iteration:
                # No further forward input is expected; only send the gradient.
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
            else:
                input_tensor = \
                    send_backward_recv_forward(
                        input_tensor_grad, recv_tensor_shapes, dtype, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, dtype, timers=timers)

            input_tensor_grad = \
                backward_step(grad_scaler, input_tensor, output_tensor,
                              output_tensor_grad, model_type, timers)

            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

    return forward_data_store