distributed.py 35.8 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""
6
7
import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
8
from typing import Any, Dict, Union, Optional, Callable, Tuple, List
9

Przemek Tredak's avatar
Przemek Tredak committed
10
import torch
11
from torch.cuda import _lazy_call, _lazy_init
12
from torch.utils.checkpoint import detach_variable, noop_context_fn
13
14
15
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
Przemek Tredak's avatar
Przemek Tredak committed
16
17
18

from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
19
from .fp8 import FP8GlobalStateManager
20
from .float8_tensor import Float8Tensor
Przemek Tredak's avatar
Przemek Tredak committed
21

22
23
24
25

__all__ = ["checkpoint", "CudaRNGStatesTracker"]


# Attribute names (and their neutral values) that mark a tensor's tensor-parallel partitioning.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = dict(
    tensor_model_parallel=False,
    partition_dim=-1,
    partition_stride=1,
)

# Selects the reentrant (autograd.Function based) recompute method; `checkpoint` overwrites
# this from its `use_reentrant` keyword argument on every call.
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True

# Flags written by `activation_recompute_forward` and read through
# `is_fp8_activation_recompute_enabled()` / `in_fp8_activation_recompute_phase()`.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False


# Registry of generator states used by `CudaRNGStatesTracker`
# (presumably the tracker's name -> state mapping; confirm against callers).
_ALL_ACTIVE_RNG_STATES = {}


def get_all_rng_states() -> Dict[str, Any]:
    """Return the registry of generator states used by `CudaRNGStatesTracker`.

    The registry is returned by reference (not copied), so mutating the result
    mutates the module-level registry.
    """
    return _ALL_ACTIVE_RNG_STATES


def set_all_rng_states(states: Dict[str, Any]) -> None:
    """Replace the registry of generator states used by `CudaRNGStatesTracker`.

    Parameters
    ----------
    states : Dict[str, Any]
        mapping from RNG state names to generator states. It is stored by
        reference, so later mutations of `states` are visible through
        `get_all_rng_states()`.
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states


def graph_safe_rng_available() -> bool:
    """Report whether this PyTorch build supports CUDA-graph-safe RNG state handling.

    All four graph-safe generator APIs must be present for the result to be `True`.
    """
    required_apis = (
        (torch.cuda.CUDAGraph, "register_generator_state"),
        (torch.Generator, "graphsafe_set_state"),
        (torch.Generator, "graphsafe_get_state"),
        (torch.Generator, "clone_state"),
    )
    return all(hasattr(owner, api_name) for owner, api_name in required_apis)


def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> torch.Tensor:
    """Fetch the RNG state of the default CUDA generator for `device`.

    `_lazy_init()` is called first so the CUDA default generators can be queried.
    A device without an index resolves to the current CUDA device. When graph-safe
    RNG APIs are available and `graph_safe` is set, the graph-safe state is returned:
    a cloned state if `clone` is set, otherwise a reference to the live state.
    In every other case the plain `get_state()` snapshot is returned.
    """
    _lazy_init()
    if isinstance(device, int):
        device = torch.device("cuda", device)
    elif isinstance(device, str):
        device = torch.device(device)

    gen_index = torch.cuda.current_device() if device.index is None else device.index
    generator = torch.cuda.default_generators[gen_index]

    if not (graph_safe_rng_available() and graph_safe):
        return generator.get_state()
    return generator.clone_state() if clone else generator.graphsafe_get_state()


def _set_cuda_rng_state(
    new_state: torch.Tensor,
    device: Union[int, str] = -1,
    graph_safe=True,
) -> None:
    """Install `new_state` into the default CUDA generator of `device`.

    `device=-1` selects the current CUDA device. A string or integer is converted
    to a `torch.device`, and any other value is used as-is. The update is
    scheduled through `torch.cuda._lazy_call` (presumably run immediately when
    CUDA is already initialized, otherwise queued — see PyTorch internals). When
    graph-safe RNG APIs are available and `graph_safe` is set, the state is
    installed with `graphsafe_set_state`; otherwise with `set_state`.
    """
    if device == -1:
        target = torch.device("cuda")
    elif isinstance(device, str):
        target = torch.device(device)
    elif isinstance(device, int):
        target = torch.device("cuda", device)
    else:
        target = device

    def _apply_state() -> None:
        # The generator index is resolved when the callback runs, not when it is queued.
        gen_index = target.index
        if gen_index is None:
            gen_index = torch.cuda.current_device()
        generator = torch.cuda.default_generators[gen_index]
        if graph_safe_rng_available() and graph_safe:
            generator.graphsafe_set_state(new_state)
        else:
            generator.set_state(new_state)

    _lazy_call(_apply_state)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Tag `tensor` with its tensor-parallel partitioning metadata.

    Sets `tensor_model_parallel`, `partition_dim` and `partition_stride` on the
    tensor. An `assert` rejects tensors that already carry any attribute named in
    `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS`.
    """
    # Refuse to overwrite metadata that was attached earlier.
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attr_name)
    new_attrs = (
        ("tensor_model_parallel", is_parallel),
        ("partition_dim", dim),
        ("partition_stride", stride),
    )
    for attr_name, attr_value in new_attrs:
        setattr(tensor, attr_name, attr_value)


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return the number of ranks in `group` (the default group when `None`).

    Returns 1 when `torch.distributed` has not been initialized, so a
    single-process run behaves like a world of size one.
    """
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=group)
    return 1


def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return this process's rank within `group` (the default group when `None`).

    Unlike `get_distributed_world_size` there is no single-process fallback:
    an `assert` requires `torch.distributed` to be initialized.
    """
    initialized = torch.distributed.is_initialized()
    assert initialized, "torch.distributed is not initialized."
    return torch.distributed.get_rank(group=group)


def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Optional[Callable],
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize affine weight for model parallel on GPU.

    Parameters
    ----------
    weight : torch.Tensor
        parameter to initialize; `init_method(weight)` is called on it.
    init_method : Callable
        initializer invoked with `weight` as its only argument.
    get_rng_state_tracker : Optional[Callable]
        callable returning an RNG states tracker; the initialization runs inside
        that tracker's `fork()` context. When `None`, `init_method` runs under the
        ambient RNG state.
    partition_dim : int, default = 0
        stored as the tensor's `partition_dim` attribute when `set_tp_attributes`.
    stride : int, default = 1
        stored as the tensor's `partition_stride` attribute when `set_tp_attributes`.
    set_tp_attributes : bool, default = True
        whether to tag `weight` with tensor-parallel attributes before initializing.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    if get_rng_state_tracker is None:
        init_method(weight)
        return

    # Draw the initial values from the tracker's forked RNG stream.
    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's contiguous 1D slice of `tensor` within `tp_group`.

    The flattened tensor is cut into `world_size` slices of `numel // world_size`
    elements, so any remainder is dropped. With `new_buffer=True` the slice is
    copied into a freshly allocated tensor on the current CUDA device; otherwise
    a view into `tensor` is returned.
    """
    chunk_len = torch.numel(tensor) // get_distributed_world_size(tp_group)
    begin = chunk_len * get_distributed_rank(tp_group)
    end = begin + chunk_len
    if not new_buffer:
        return tensor.view(-1)[begin:end]
    buffer = torch.empty(
        chunk_len,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    buffer.copy_(tensor.view(-1)[begin:end])
    return buffer


def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor:
    """Inverse of `split_tensor_into_1d_equal_chunks`.

    All-gathers every rank's 1D slice from `tp_group` into one flat tensor of
    `numel(tensor) * world_size` elements, allocated on the current CUDA device.
    """
    slice_len = torch.numel(tensor)
    output = torch.empty(
        slice_len * get_distributed_world_size(tp_group),
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(output, tensor, group=tp_group)
    return output


class activation_recompute_forward(AbstractContextManager, ContextDecorator):
    """Context manager used to control the forward runtime behavior when executed
    under the `CheckpointFunction` function. For running FP8, the forward pass will
    run without storing intermediate activations. Instead, the forward pass saves
    the inputs tuple and the calling function. In the backwards pass, these are
    retrieved, and the forward pass is computed again while tracking the intermediate
    activations, followed by calculation of gradients using these values.

    On entry, the module-level flags read by `is_fp8_activation_recompute_enabled()`
    and `in_fp8_activation_recompute_phase()` are set from this object's configuration.
    On exit they are restored to the values they had on entry. Nested uses, such as
    a checkpointed region inside another checkpointed region or its recompute, therefore
    do not reset the outer region's flags.
    """

    def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
        super().__init__()
        self.activation_recompute = activation_recompute
        self.recompute_phase = recompute_phase
        # (enabled, phase) flag pairs saved by each __enter__. This is a stack so that the
        # same instance can be re-entered, e.g. as a decorator on a recursive function.
        self._saved_flags = []

    def __enter__(self):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        # Compute the new value before touching the stack, so that a failure here
        # leaves the stack balanced.
        enabled = self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        self._saved_flags.append(
            (_FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE)
        )
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = enabled
        _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase

    def __exit__(self, *exc_details):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        if self._saved_flags:
            previous_enabled, previous_phase = self._saved_flags.pop()
            _FP8_ACTIVATION_RECOMPUTE_ENABLED = previous_enabled
            _FP8_ACTIVATION_RECOMPUTE_PHASE = previous_phase
        else:
            # Unbalanced exit (no matching __enter__): fall back to the disabled state.
            _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
            _FP8_ACTIVATION_RECOMPUTE_PHASE = False


def is_fp8_activation_recompute_enabled() -> bool:
    """Report the module-level flag set by `activation_recompute_forward`.

    It reflects whether the active context requested activation recompute
    while FP8 was enabled.
    """
    flag = _FP8_ACTIVATION_RECOMPUTE_ENABLED
    return flag


def in_fp8_activation_recompute_phase() -> bool:
    """Report the module-level flag set by `activation_recompute_forward`.

    It is True while running inside a context created with `recompute_phase=True`,
    i.e. the backward-pass recomputation of a checkpointed forward.
    """
    flag = _FP8_ACTIVATION_RECOMPUTE_PHASE
    return flag


def _get_active_autocast_contexts():
    """
    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
    at the time of this function's execution.

    Returns
    -------
    (gpu_autocast_ctx, cpu_autocast_ctx)
        `torch.amp.autocast` instances for the "cuda" and "cpu" device types, carrying the
        currently active enabled flag, dtype and cache setting.

    The device-generic `torch.amp.autocast(device_type, ...)` constructor is used instead of
    `torch.cuda.amp.autocast` / `torch.cpu.amp.autocast`. The device-generic state queries
    (`torch.is_autocast_enabled(device_type)`, `torch.get_autocast_dtype(device_type)`) are used
    when this PyTorch provides them. The device-specific spellings are deprecated in recent
    PyTorch releases and emit deprecation warnings.
    """
    autocast_cached = torch.is_autocast_cache_enabled()

    if hasattr(torch, "get_autocast_dtype"):
        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
    else:
        # Older PyTorch: only the device-specific query functions exist.
        gpu_autocast_enabled = torch.is_autocast_enabled()
        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()

    gpu_autocast_ctx = torch.amp.autocast(
        "cuda",
        enabled=gpu_autocast_enabled,
        dtype=gpu_autocast_dtype,
        cache_enabled=autocast_cached,
    )
    cpu_autocast_ctx = torch.amp.autocast(
        "cpu",
        enabled=cpu_autocast_enabled,
        dtype=cpu_autocast_dtype,
        cache_enabled=autocast_cached,
    )
    return gpu_autocast_ctx, cpu_autocast_ctx


class _CheckpointFunction(torch.autograd.Function):
    """Reentrant activation checkpointing, adapted from `torch.utils.checkpoint`.

    It differs from the upstream implementation in two ways:
        1) CUDA RNG state is captured/restored with `_get_cuda_rng_state` /
           `_set_cuda_rng_state` (non-graph-safe) rather than `torch.cuda.set_rng_state`.
        2) the states held by the model-parallel RNG tracker, when one is given, are
           captured in forward and restored around the recompute as well.
    The forward pass runs without grad. The backward pass re-runs `run_function` with grad
    enabled under the original RNG and autocast state, then back-propagates through it.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_rng_state_tracker: Union[Callable, None],
        tp_group: Union[dist_group_type, None],
        context_fn: Union[Callable, None],
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Run `run_function` without building a graph and stash everything the
        backward pass needs to recompute it."""
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Snapshot RNG states so that the recompute reproduces the same random draws.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        make_contexts = noop_context_fn if context_fn is None else context_fn
        forward_ctx, recompute_ctx = make_contexts()

        # Capture the autocast state now; it is re-applied during the recompute.
        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()

        with (
            torch.no_grad(),
            forward_ctx,
            activation_recompute_forward(activation_recompute=True, recompute_phase=False),
        ):
            outputs = run_function(*args, **kwargs)

        # Optionally keep only this rank's 1D shard of the first input.
        if distribute_saved_activations:
            first_input = args[0]
            ctx.input_0_shape = first_input.data.shape
            safely_set_viewless_tensor_data(
                first_input,
                split_tensor_into_1d_equal_chunks(first_input.data, tp_group, new_buffer=True),
            )

        # Tensors go through save_for_backward; everything else is kept on ctx.
        ctx.inputs = []
        tensor_inputs = []
        for arg in args:
            if torch.is_tensor(arg):
                ctx.inputs.append(None)
                tensor_inputs.append(arg)
            else:
                ctx.inputs.append(arg)
                tensor_inputs.append(None)
        ctx.save_for_backward(*tensor_inputs)

        ctx.get_rng_state_tracker = get_rng_state_tracker
        ctx.tp_group = tp_group
        ctx.recompute_ctx = recompute_ctx
        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Recompute the forward pass with grad enabled, then back-propagate the
        incoming output gradients (`args`) through the recomputed graph."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        inputs = tuple(
            saved if saved is not None else stashed
            for saved, stashed in zip(ctx.saved_tensors, ctx.inputs)
        )
        tracker_fn = ctx.get_rng_state_tracker

        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape),
            )

        # Remember the RNG states of the backward pass...
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if tracker_fn is not None:
            bwd_cuda_rng_state_tracker = tracker_fn().get_states()

        # ...and rewind to the states captured before the original forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
        if tracker_fn is not None:
            tracker_fn().set_states(ctx.fwd_cuda_rng_state_tracker)

        detached_inputs = detach_variable(inputs)
        with (
            torch.enable_grad(),
            ctx.recompute_ctx,
            ctx.torch_gpu_amp_ctx,
            ctx.torch_cpu_amp_ctx,
            activation_recompute_forward(activation_recompute=True, recompute_phase=True),
        ):
            outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

        # Put the backward-pass RNG states back in place.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
        if tracker_fn is not None:
            tracker_fn().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        grad_outputs = []
        grad_targets = []
        for position, out in enumerate(outputs):
            if torch.is_tensor(out) and out.requires_grad:
                grad_outputs.append(out)
                grad_targets.append(args[position])
        if not grad_outputs:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )

        torch.autograd.backward(grad_outputs, grad_targets)
        input_grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
        )
        # No gradients for the six non-tensor forward arguments that precede *args.
        return (None,) * 6 + input_grads


class _CheckpointFrame:
    """
    Per-forward-pass storage for the non-reentrant checkpoint. It holds the RNG
    states captured around the forward and backward passes, and the detached
    activations produced by the forward recompute.
    """

    def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):
        self.recompute_fn = recompute_fn
        self.get_rng_state_tracker = get_rng_state_tracker
        # Activations saved during the recompute, in save order.
        self.recomputed = []
        # Number of tensors packed so far during the original forward pass.
        self.count = 0
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Snapshot the CPU, CUDA and (if a tracker is set) tracker RNG states into
        the forward slot (`forward=True`) or the backward slot (`forward=False`)."""
        snapshot = [torch.get_rng_state(), _get_cuda_rng_state(graph_safe=False)]
        if self.get_rng_state_tracker is not None:
            snapshot.append(self.get_rng_state_tracker().get_states())
        snapshot = tuple(snapshot)
        if forward:
            self.fwd_rng_states = snapshot
        else:
            self.bwd_rng_states = snapshot

    def restore_rng_states(self, forward=True):
        """Re-apply the RNG states previously stored by `cache_rng_states`."""
        saved = self.fwd_rng_states if forward else self.bwd_rng_states
        torch.set_rng_state(saved[0])
        _set_cuda_rng_state(saved[1], graph_safe=False)
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(saved[2])


class _recomputation_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the activation recompute phase."""

    def __init__(self, frame):
        def _stash_detached(tensor):
            """Pack hook: record a detached copy of every tensor the recomputed forward
            saves for backward into `frame.recomputed`, and give autograd a detached
            tensor as well."""
            frame.recomputed.append(tensor.detach())
            return tensor.detach()

        def _passthrough(packed):
            """Unpack hook: identity. The graph built during the recompute is never
            back-propagated through, so this is not expected to run."""
            return packed

        super().__init__(_stash_detached, _passthrough)


class _checkpoint_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass."""

    def __init__(self, frame, args, kwargs):
        def pack_hook(x):
            """
            Stand in for each tensor autograd wants to save with a running integer
            index. The tensor itself is dropped and recomputed in the backward pass.
            """
            del x
            slot = frame.count
            frame.count = slot + 1
            return slot

        def unpack_hook(idx):
            """
            Return the recomputed activation stored at index `idx`.

            On the first call (nothing recomputed yet), the forward pass is re-run with
            the cached forward RNG states, with `_recomputation_hook` collecting its
            saved tensors into `frame.recomputed`. The backward RNG states are put back
            afterwards. Each slot is cleared once handed out, so the frame no longer
            holds a reference to that activation.
            """
            if not frame.recomputed:
                frame.cache_rng_states(forward=False)
                frame.restore_rng_states(forward=True)
                with _recomputation_hook(frame):
                    frame.recompute_fn(*args, **kwargs)
                frame.restore_rng_states(forward=False)

            activation, frame.recomputed[idx] = frame.recomputed[idx], None
            return activation

        super().__init__(pack_hook, unpack_hook)


def use_reentrant_activation_recompute():
    """Report whether activation recompute uses the 'reentrant' method.

    The flag is set from the `use_reentrant` argument of the most recent `checkpoint`
    call and is `True` before any call.
    """
    reentrant = _USE_REENTRANT_ACTIVATION_RECOMPUTE
    return reentrant


def get_activation_recompute_contexts():
    """Build the pair of `activation_recompute_forward` contexts used by checkpointing.

    Returns
    -------
    (forward_ctx, recompute_ctx)
        Both request activation recompute. The first marks the original forward pass
        (`recompute_phase=False`) and the second the recompute phase (`recompute_phase=True`).
    """
    return tuple(
        activation_recompute_forward(activation_recompute=True, recompute_phase=phase)
        for phase in (False, True)
    )


def has_te_modules(network):
    """
    Check if there are any Transformer Engine modules in the network.

    For a `torch.nn.Module`, returns True when the module itself or any of its
    submodules is an instance of one of the TE classes imported below. Any other
    callable is conservatively assumed to contain TE modules, because its
    contents cannot be inspected.
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .transformer import TransformerLayer

    if not isinstance(network, torch.nn.Module):
        # Cannot check for TE modules inside a custom class/callable that's not a
        # torch.nn.Module, so just assume that it has TE modules just to be safe.
        return True

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    return any(isinstance(submodule, te_classes) for submodule in network.modules())


@torch._disable_dynamo
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::
        `use_reentrant=False` does not support early stopping, and will execute the entire forward
        pass for the checkpointed module when recomputing activations in the backward pass.

    Parameters
    ----------
    function: Callable
            pytorch module used to run the forward and backward passes using
            the specified :attr:`args` and :attr:`kwargs`. A `torch.nn.Module` that
            contains no Transformer Engine modules is checkpointed with
            `torch.utils.checkpoint.checkpoint` instead.
    distribute_saved_activations: bool, default = False
            if set to `True` and `use_reentrant=True`, first tensor argument is distributed
            across the specified tensor parallel group (`tp_group`) before saving it for the
            backward pass. This has no effect when `use_reentrant=False`.
    get_rng_state_tracker: `Callable`, default = None
            python callable which returns an instance of :class:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
            tensor parallel process group. Used only when `distribute_saved_activations=True`
            and `use_reentrant=True`. If `None`, it falls back to the default group.
    use_reentrant : bool, default = True
            perform checkpointing in reentrant mode.
    context_fn : Callable, default = `torch.utils.checkpoint.noop_context_fn`
            callable returning a pair of context managers, entered during the forward
            pass and during the recomputation respectively.
    determinism_check : str, default = "default"
            forwarded to `torch.utils.checkpoint.checkpoint` when that implementation is
            used (see `function`); ignored otherwise.
    debug : bool, default = False
            forwarded to `torch.utils.checkpoint.checkpoint` when that implementation is
            used (see `function`); ignored otherwise.
    args : tuple
            tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
            dictionary of string keys for keyword arguments to :attr:`function`.
    """
    # Pop out te.distributed.checkpoint() arguments
    global _USE_REENTRANT_ACTIVATION_RECOMPUTE
    _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
    distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
    tp_group = kwargs.pop("tp_group", None)
    get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

    # Ensure backward compatibility.
    if (
        len(args) > 3
        and isinstance(args[0], bool)
        and callable(args[1])
        and isinstance(args[2], None | dist_group_type)
    ):
        warnings.warn(
            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
            DeprecationWarning,
            stacklevel=2,
        )
        distribute_saved_activations = args[0]
        get_rng_state_tracker = args[1]
        tp_group = args[2]
        args = args[3:]

    # Trigger the native PyTorch checkpoint if the function is not or does not contain a
    # Transformer Engine module.
    context_fn = kwargs.pop("context_fn", noop_context_fn)
    determinism_check = kwargs.pop("determinism_check", "default")
    debug = kwargs.pop("debug", False)
    if not has_te_modules(function):
        return torch.utils.checkpoint.checkpoint(
            function,
            *args,
            use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
            context_fn=context_fn,
            determinism_check=determinism_check,
            debug=debug,
            **kwargs,
        )

    # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
    # to scatter/gather activations that we will recompute anyway. Only modules can carry that
    # state. Other callables reach this point too, because has_te_modules() returns True for them,
    # and some of them (e.g. bound methods) reject attribute assignment. So leave non-modules
    # untouched.
    if isinstance(function, torch.nn.Module):
        setattr(function, "fsdp_wrapped", False)
        setattr(function, "fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
    # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
    #       cannot be sure there are no TE modules inside the function. It also means we might run
    #       the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
    #       user context function.
    del determinism_check, debug
    if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
        # If saved activations need to be distributed but there is no process group,
        # default to the world group.
        if distribute_saved_activations:
            assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
            tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

        return _CheckpointFunction.apply(
            function,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            context_fn,
            kwargs,
            *args,
        )

    if distribute_saved_activations:
        warnings.warn(
            "`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
            "The non-reentrant checkpoint implementation does not manually store forward "
            "inputs for the activation recompute in the backward pass, and instead leverages "
            "the autograd engine's pack/unpack hooks."
        )

    user_forward_ctx, user_recompute_ctx = context_fn()
    te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()

    # Preserve the torch autocast contexts from the forward pass during recompute phase.
    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

    def recompute_fn(*args, **kwargs):
        with (
            torch.autograd.enable_grad(),
            te_recompute_ctx,
            user_recompute_ctx,
            torch_gpu_amp_forward_ctx,
            torch_cpu_amp_forward_ctx,
        ):
            function(*args, **kwargs)

    # Initialize a new checkpoint frame for each new forward pass.
    new_frame = _CheckpointFrame(
        recompute_fn,
        get_rng_state_tracker,
    )
    new_frame.cache_rng_states(forward=True)

    with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
        out = function(*args, **kwargs)

    return out


class CudaRNGStatesTracker:
    """
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the `add` method, a
    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracker).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Get rng states as a new dictionary.

        Only the mapping is copied: the values are the very same state objects held by
        the tracker, so callers get direct references to the states rather than a
        reference to the tracker's internal dictionary.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        states: Dict[str, torch.Tensor]
               A mapping from string names to RNG states. The dictionary is stored
               as-is (not copied).
        """
        self.states_ = states

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        name: str
             string identifier for the RNG state.
        seed: int
             PyTorch seed for the RNG state.

        Raises `Exception` if `seed` was already used or `name` is already registered.
        Both checks run before any bookkeeping is modified, so a rejected call leaves
        the tracker unchanged.
        """
        # Validate up front: previously the seed was recorded before the name check,
        # so a duplicate-name failure left the seed permanently marked as used.
        if seed in self.seeds_:
            raise Exception(f"seed {seed} already exists")
        if name in self.states_:
            raise Exception(f"cuda rng state {name} already exists")

        if graph_safe_rng_available():
            new_state = _get_cuda_rng_state(clone=True)
            new_state.manual_seed(seed)
            self.states_[name] = new_state
        else:
            # Seed the current CUDA RNG, snapshot it under `name`, then put the caller's
            # state back -- even if seeding or snapshotting raises.
            orig_rng_state = _get_cuda_rng_state()
            try:
                torch.cuda.manual_seed(seed)
                self.states_[name] = _get_cuda_rng_state(clone=True)
            finally:
                _set_cuda_rng_state(orig_rng_state)

        # Record the seed only once its state has been created successfully.
        self.seeds_.add(seed)
        # Update global states.
        set_all_rng_states(self.states_)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        When the graph-safe RNG API is not available, the state reached at the end of
        the block is stored back under `name`, so the next fork of `name` continues
        from where this one stopped.

        name: str
             string identifier for the RNG state.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise Exception(f"cuda rng state {name} is not added")
        # Get the reference to current rng state.
        orig_cuda_rng_state = _get_cuda_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Skipped with the graph-safe API; presumably the tracked state object is
            # advanced in place there -- NOTE(review): confirm.
            if not graph_safe_rng_available():
                self.states_[name] = _get_cuda_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


Przemek Tredak's avatar
Przemek Tredak committed
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
def reduce_scatter_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reduce-scatter ``input_`` across the tensor-parallel group along dimension 0.

    The result is reduced with the collective's default reduction op. Each rank
    keeps a slice whose first dimension is ``input_.size(0) // world_size``. The
    output is allocated on the current CUDA device.

    Returns
    -------
    (output, handle)
        For a single-rank group, ``input_`` is returned unchanged with a ``None``
        handle. Otherwise ``handle`` is whatever
        ``torch.distributed.reduce_scatter_tensor`` returns for ``async_op``.
    """
    world_size = get_distributed_world_size(tp_group)
    if world_size == 1:
        # Nothing to exchange with a single rank.
        return input_, None

    out_shape = list(input_.size())
    assert (
        out_shape[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    out_shape[0] //= world_size

    scattered = torch.empty(out_shape, dtype=input_.dtype, device=torch.cuda.current_device())
    work = torch.distributed.reduce_scatter_tensor(
        scattered, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return scattered, work


def gather_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    All-gather ``input_`` from every rank of ``tp_group`` and concatenate along dimension 0.

    The output's first dimension is ``world_size`` times that of ``input_``, and all
    other dimensions are unchanged. The output is allocated on the current CUDA device.

    Returns
    -------
    (output, handle)
        For a single-rank group, ``input_`` is returned unchanged with a ``None``
        handle. Otherwise ``handle`` is whatever
        ``torch.distributed.all_gather_into_tensor`` returns for ``async_op``.
    """
    world_size = get_distributed_world_size(tp_group)
    if world_size == 1:
        # A one-rank gather is the identity; skip communication.
        return input_, None

    out_shape = list(input_.size())
    out_shape[0] *= world_size

    gathered = torch.empty(out_shape, dtype=input_.dtype, device=torch.cuda.current_device())
    work = torch.distributed.all_gather_into_tensor(
        gathered, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return gathered, work


def allreduce(
    input_: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    All-reduce ``input_`` in place across the model-parallel group.

    Returns
    -------
    (input_, handle)
        ``input_`` is always the same tensor object that was passed in.
        ``handle`` is ``None`` for a single-rank group; otherwise it is whatever
        ``torch.distributed.all_reduce`` returns for ``async_op``.
    """
    if get_distributed_world_size(tp_group) == 1:
        # Single rank: the reduction is a no-op.
        return input_, None

    # The collective writes the reduced values back into ``input_`` itself.
    return input_, torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
867
868
869
870
871
872
873
874
875
876
877
878
879


def _fsdp_scatter_tensors(
    fsdp_group: dist_group_type,
    *tensors: torch.Tensor,
):
    """
    Replace each tensor's data in place with this rank's chunk across ``fsdp_group``.

    The chunk is produced by ``split_tensor_into_1d_equal_chunks(..., new_buffer=True)``.
    For a ``Float8Tensor``, the operation is applied to its ``_data`` attribute instead
    of the wrapper.

    Returns a list aligned with ``tensors``. Each entry holds the pre-scatter shape of
    that tensor's data, or ``None`` for non-tensor arguments. When ``fsdp_group`` is
    ``None``, nothing is modified and an empty list is returned.
    """
    if fsdp_group is None:
        return []

    original_shapes = []
    for tensor in tensors:
        if not isinstance(tensor, torch.Tensor):
            original_shapes.append(None)
            continue
        storage = tensor._data if isinstance(tensor, Float8Tensor) else tensor
        # Record the shape before the data is replaced by a 1D chunk.
        original_shapes.append(storage.data.shape)
        chunk = split_tensor_into_1d_equal_chunks(storage.data, fsdp_group, new_buffer=True)
        safely_set_viewless_tensor_data(storage, chunk)
    return original_shapes


def _fsdp_gather_tensors(
    fsdp_group: dist_group_type,
    shapes: List[Tuple[int, ...]],
    *tensors: torch.Tensor,
):
    """
    Inverse of ``_fsdp_scatter_tensors``: gather each tensor's chunks across
    ``fsdp_group`` and restore the recorded shape, in place.

    ``shapes`` must be aligned one-to-one with ``tensors``. It is expected to be the
    list returned by ``_fsdp_scatter_tensors``. For a ``Float8Tensor``, the operation
    is applied to its ``_data`` attribute. Non-tensor arguments are skipped. Nothing
    happens when ``fsdp_group`` is ``None``.
    """
    if fsdp_group is None:
        return

    assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal."
    for tensor, shape in zip(tensors, shapes):
        if not isinstance(tensor, torch.Tensor):
            continue
        # A tensor argument must have had its shape recorded during the scatter.
        assert shape is not None, "Internal TE error."
        storage = tensor._data if isinstance(tensor, Float8Tensor) else tensor
        gathered = gather_split_1d_tensor(storage.data, fsdp_group)
        safely_set_viewless_tensor_data(storage, gathered.view(shape))


def _is_te_module(module):
    """
    Check if given module is a Transformer Engine module that requires the TE checkpoint
    implementation for activation recompute.

    Returns ``True`` if ``module`` is an instance of any of the TE classes listed below
    (including subclasses), ``False`` otherwise.
    """
    # Imported at call time rather than module load; presumably to avoid circular
    # imports with these TE submodules -- NOTE(review): confirm.
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    # isinstance accepts a tuple and is True if any class matches.
    return isinstance(module, te_classes)


def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    """
    Inject FSDP process group references into FSDP-wrapped TE modules in an FSDP-wrapped root
    module in order to scatter/gather the FP8 weight copies at the same time FSDP scatters/gathers
    its `FlatParameters`.

    Every TE module found gets an ``fsdp_group`` attribute set to the ``process_group``
    of its FSDP state. This covers the module wrapped by ``fsdp_root`` itself and the
    wrapped module of every FSDP state returned by ``_get_fsdp_states_with_modules``.

    Parameters
    ----------
    fsdp_root: torch.nn.Module
               FSDP-wrapped root module that may contain FSDP-wrapped TE modules.

    Raises
    ------
    AssertionError
        If ``fsdp_root`` is not FSDP-wrapped, if a TE module reports
        ``primary_weights_in_fp8``, or if the root has no valid ``_FSDPState``.
    """
    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."

    def _check_not_fp8_primary(te_module: torch.nn.Module) -> None:
        # Modules lacking the attribute are treated as not holding FP8 primary weights.
        assert not getattr(te_module, "primary_weights_in_fp8", False), (
            "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
            "Please initialize your model without the te.fp8_model_init(...) context."
        )

    # If the root module is a TE module, inject FSDP information into it.
    root_module = fsdp_root.module
    if _is_te_module(root_module):
        # Inspect the wrapped module (not the FSDP wrapper), matching the submodule loop.
        _check_not_fp8_primary(root_module)
        root_state = _get_module_fsdp_state(fsdp_root)
        assert root_state is not None, "Root module does not have a valid _FSDPState."
        root_module.fsdp_group = root_state.process_group

    # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules.
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
    for state, fsdp_module in zip(fsdp_states, fsdp_modules):
        wrapped = fsdp_module.module
        if _is_te_module(wrapped):
            _check_not_fp8_primary(wrapped)
            wrapped.fsdp_group = state.process_group


class FullyShardedDataParallel(FSDP):
    """
    Drop-in replacement for `torch.distributed.fsdp.FullyShardedDataParallel`.

    After the regular FSDP construction, the instance runs `prepare_te_modules_for_fsdp`
    on itself. This hands the TE modules inside the wrap the FSDP process group they use
    to scatter their activation tensors after each forward pass and gather them before
    the backward pass.
    """

    def __init__(self, module, *fsdp_args, **fsdp_kwargs):
        # FSDP must build its states first; the TE preparation walks those states.
        super().__init__(module, *fsdp_args, **fsdp_kwargs)
        prepare_te_modules_for_fsdp(self)