# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""
6
7
from __future__ import annotations

8
from collections.abc import Iterable
9
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
10
from functools import lru_cache
11
import math
12
13
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
14

Przemek Tredak's avatar
Przemek Tredak committed
15
import torch
16
from torch.cuda import _lazy_call, _lazy_init
17
from torch.utils.checkpoint import detach_variable, noop_context_fn
18
19
20
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
Przemek Tredak's avatar
Przemek Tredak committed
21

22
23
24
25
26
from .utils import (
    is_non_tn_fp8_gemm_supported,
    safely_set_viewless_tensor_data,
    needs_quantized_gemm,
)
Przemek Tredak's avatar
Przemek Tredak committed
27
from .constants import dist_group_type
28
from .fp8 import FP8GlobalStateManager, fp8_autocast
29
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
30
from .tensor.mxfp8_tensor import MXFP8Quantizer
31
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
32
33
34
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
35
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
36
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
37

38
39
40
41
42
43
44
try:
    import torch.distributed._symmetric_memory as symm_mem

    HAS_TORCH_SYMMETRIC = True
except ImportError:
    HAS_TORCH_SYMMETRIC = False

45
46
47
__all__ = ["checkpoint", "CudaRNGStatesTracker"]


# Names (with default values) of the tensor-parallel metadata attributes.
# `set_tensor_model_parallel_attributes` refuses to overwrite any of these keys.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

# Rewritten by every `checkpoint()` call from its `use_reentrant` keyword (default True).
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True

# Module-level flags toggled by `activation_recompute_forward` and read through
# `is_fp8_activation_recompute_enabled()` / `in_fp8_activation_recompute_phase()`.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = _FP8_ACTIVATION_RECOMPUTE_PHASE = False
# Mapping of RNG states used by `CudaRNGStatesTracker`; read/replaced via the accessors below.
_ALL_ACTIVE_RNG_STATES = {}


def get_all_rng_states() -> Dict[str, Any]:
    """Returns all generator states used by `CudaRNGStatesTracker`.

    Returns:
        The module-level state mapping itself (not a copy).
    """
    return _ALL_ACTIVE_RNG_STATES


def set_all_rng_states(states: Dict[str, Any]) -> None:
    """Updates all generator states used by `CudaRNGStatesTracker`.

    Args:
        states: mapping that replaces the module-level registry. The object is stored
            as-is (no copy), so later mutations of ``states`` are visible through
            :func:`get_all_rng_states`.
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states


def graph_safe_rng_available() -> bool:
    """Report whether this PyTorch build supports CUDA-graph-safe RNG state handling.

    Returns True only when ``torch.cuda.CUDAGraph`` exposes ``register_generator_state``
    and ``torch.Generator`` exposes ``graphsafe_set_state``, ``graphsafe_get_state`` and
    ``clone_state``.
    """
    required_apis = (
        (torch.cuda.CUDAGraph, "register_generator_state"),
        (torch.Generator, "graphsafe_set_state"),
        (torch.Generator, "graphsafe_get_state"),
        (torch.Generator, "clone_state"),
    )
    return all(hasattr(owner, attr_name) for owner, attr_name in required_apis)
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111


def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> Union[torch.Tensor, torch.Generator]:
    """Return the random number generator state of the specified GPU.

    Args:
        device: CUDA device given as an index, a device string or a ``torch.device``.
            A device without an explicit index resolves to the current CUDA device.
        clone: only used on the graph-safe path; when True return
            ``default_generator.clone_state()`` instead of ``graphsafe_get_state()``.
        graph_safe: request the CUDA-graph-safe accessors. They are used only when
            :func:`graph_safe_rng_available` also reports support.

    Returns:
        On the graph-safe path, the object returned by ``clone_state()`` /
        ``graphsafe_get_state()`` (a ``torch.Generator``); otherwise the state tensor from
        ``get_state()``.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        idx = torch.cuda.current_device()
    default_generator = torch.cuda.default_generators[idx]
    if graph_safe_rng_available() and graph_safe:
        if clone:
            # Reference to the cloned generator state
            return default_generator.clone_state()
        # Reference to the current generator state
        return default_generator.graphsafe_get_state()
    return default_generator.get_state()


def _set_cuda_rng_state(
    new_state: Union[torch.Tensor, torch.Generator],
    device: Union[int, str, torch.device] = -1,
    graph_safe: bool = True,
) -> None:
    """Sets the random number generator state of the current GPU.

    The update is wrapped in a callback passed to ``torch.cuda._lazy_call``, which runs it
    immediately when CUDA is initialized and otherwise defers it until initialization.

    Args:
        new_state: state to install, typically obtained from :func:`_get_cuda_rng_state`
            with the same ``graph_safe`` setting.
        device: ``-1`` (default) selects the current CUDA device. Otherwise an index, a
            device string or a ``torch.device``; a device without an index resolves to the
            current CUDA device when the callback runs.
        graph_safe: use ``graphsafe_set_state`` when :func:`graph_safe_rng_available`
            also reports support; otherwise fall back to ``set_state``.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        if graph_safe_rng_available() and graph_safe:
            default_generator.graphsafe_set_state(new_state)
            return
        default_generator.set_state(new_state)

    _lazy_call(cb)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Attach the tensor-parallel metadata attributes to ``tensor``.

    Sets ``tensor_model_parallel``, ``partition_dim`` and ``partition_stride``.
    Raises ``AssertionError`` (via ``assert``) if the tensor already carries any attribute
    named in ``_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS``.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    for name, value in (
        ("tensor_model_parallel", is_parallel),
        ("partition_dim", dim),
        ("partition_stride", stride),
    ):
        setattr(tensor, name, value)


148
@lru_cache
def _get_initialized_world_size(group: Optional[dist_group_type] = None) -> int:
    """Memoized ``torch.distributed.get_world_size``.

    Only called once torch.distributed is initialized, so real group sizes are the only
    values that ever enter the cache.
    """
    return torch.distributed.get_world_size(group=group)


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return world size for the distributed group.

    Args:
        group: process group to query; ``None`` selects the default group.

    Returns:
        1 when ``torch.distributed`` is not initialized, otherwise the size of ``group``.
        Sizes of an initialized group are memoized per ``group`` argument. The
        not-initialized fallback is deliberately *not* cached: caching it would pin the
        result to 1 for the rest of the process if this function were called before
        ``init_process_group``.
    """
    if not torch.distributed.is_initialized():
        return 1
    return _get_initialized_world_size(group)


# Keep the functools.lru_cache management helpers reachable on the public function, as
# they were when it was decorated directly.
get_distributed_world_size.cache_info = _get_initialized_world_size.cache_info
get_distributed_world_size.cache_clear = _get_initialized_world_size.cache_clear


156
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Look up (and memoize per ``group``) this process's rank within ``group``.

    Raises ``AssertionError`` via ``assert`` when torch.distributed is not initialized.
    Failed calls are not cached by ``lru_cache``.
    """
    dist = torch.distributed
    assert dist.is_initialized(), "torch.distributed is not initialized."
    return dist.get_rank(group=group)


def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Optional[Callable],
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize affine weight for model parallel on GPU.

    Args:
        weight: tensor handed to ``init_method`` (which is expected to fill it in place).
        init_method: callable invoked as ``init_method(weight)``.
        get_rng_state_tracker: optional callable returning an RNG tracker. When given,
            ``init_method`` runs inside ``get_rng_state_tracker().fork()``; when ``None``
            it runs under the current RNG state.
        partition_dim: recorded as the ``partition_dim`` TP attribute.
        stride: recorded as the ``partition_stride`` TP attribute.
        set_tp_attributes: when True, tag ``weight`` as tensor-parallel via
            ``set_tensor_model_parallel_attributes(..., is_parallel=True, ...)``.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    if get_rng_state_tracker is None:
        init_method(weight)
        return

    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's contiguous slice of ``tensor`` viewed as 1D.

    Each rank gets ``numel // world_size`` elements starting at ``rank * that size``;
    any remainder elements are not assigned to any rank.

    Args:
        tensor: tensor to split; flattened with ``view(-1)``.
        tp_group: process group whose world size and rank select the slice.
        new_buffer: when True, copy the slice into a freshly allocated tensor on
            ``torch.cuda.current_device()``; otherwise return a view into ``tensor``.
    """
    numel_per_rank = torch.numel(tensor) // get_distributed_world_size(tp_group)
    first = numel_per_rank * get_distributed_rank(tp_group)
    last = first + numel_per_rank
    if not new_buffer:
        return tensor.view(-1)[first:last]

    chunk = torch.empty(
        numel_per_rank,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    chunk.copy_(tensor.view(-1)[first:last])
    return chunk


206
def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor:
    """Inverse of `split_tensor_into_1d_equal_chunks`: all-gather every rank's 1D chunk.

    Allocates a flat buffer of ``numel(tensor) * world_size`` elements on
    ``torch.cuda.current_device()`` and fills it with
    ``torch.distributed.all_gather_into_tensor`` over ``tp_group``.
    """
    total_numel = torch.numel(tensor) * get_distributed_world_size(tp_group)
    full = torch.empty(
        total_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(full, tensor, group=tp_group)
    return full


219
class activation_recompute_forward(AbstractContextManager, ContextDecorator):
    """Context manager used to control the forward runtime behavior when executed
    under the `CheckpointFunction` function. For running FP8, the forward pass will
    run without storing intermediate activations. Instead, the forward pass saves
    the inputs tuple and the calling function. In the backwards pass, these are
    retrieved, and the forward pass is computed again while tracking the intermediate
    activations, followed by calculation of gradients using these values.

    On entry it sets the module-level flags read by `is_fp8_activation_recompute_enabled()`
    and `in_fp8_activation_recompute_phase()`. On exit it restores the values they had
    before entry, so nested contexts (e.g. a checkpoint inside a checkpointed region) do
    not clear the enclosing context's state.
    """

    # FIFO shared by all instances: the checkpointed forward pass pushes
    # FP8GlobalStateManager.IS_FIRST_FP8_MODULE, and the matching recompute pops it
    # back from the front.
    _is_first_fp8_module: List = []

    def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
        super().__init__()
        self.activation_recompute = activation_recompute
        self.recompute_phase = recompute_phase
        # Stack of (enabled, phase) global values captured at each __enter__. It is a
        # stack rather than a single slot so that one instance may be entered repeatedly
        # (it is also usable as a decorator via ContextDecorator).
        self._saved_global_states: List[Tuple[bool, bool]] = []

    def __enter__(self):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        self._saved_global_states.append(
            (_FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE)
        )
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase

        if self.activation_recompute and not self.recompute_phase:
            activation_recompute_forward._is_first_fp8_module.append(
                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
            )
        if self.activation_recompute and self.recompute_phase:
            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
                activation_recompute_forward._is_first_fp8_module.pop(0)
            )

    def __exit__(self, *exc_details):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        if self._saved_global_states:
            # Restore the state of the enclosing context (False/False at top level).
            (
                _FP8_ACTIVATION_RECOMPUTE_ENABLED,
                _FP8_ACTIVATION_RECOMPUTE_PHASE,
            ) = self._saved_global_states.pop()
        else:
            # __exit__ without a matching __enter__: fall back to the cleared state.
            _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
            _FP8_ACTIVATION_RECOMPUTE_PHASE = False


def is_fp8_activation_recompute_enabled() -> bool:
    """Report the module-level "FP8 activation recompute enabled" flag.

    The flag is written by `activation_recompute_forward` on enter/exit.
    """
    enabled_flag = _FP8_ACTIVATION_RECOMPUTE_ENABLED
    return enabled_flag


def in_fp8_activation_recompute_phase() -> bool:
    """Report the module-level "currently in the recompute phase" flag.

    The flag is written by `activation_recompute_forward` on enter/exit.
    """
    phase_flag = _FP8_ACTIVATION_RECOMPUTE_PHASE
    return phase_flag


267
268
269
270
271
272
273
def _get_active_autocast_contexts():
    """
    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
    at the time of this function's execution.
    """
    autocast_cached = torch.is_autocast_cache_enabled()

274
275
    gpu_autocast_enabled = torch.is_autocast_enabled()
    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
276
    gpu_autocast_ctx = torch.cuda.amp.autocast(
277
278
        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
    )
279

280
281
    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
282
    cpu_autocast_ctx = torch.cpu.amp.autocast(
283
284
        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
    )
285
286
287
288

    return gpu_autocast_ctx, cpu_autocast_ctx


289
class _CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_rng_state_tracker: Union[Callable, None],
        tp_group: Union[dist_group_type, None],
        context_fn: Union[Callable, None],
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Call forward function while saving state to be able to
        redo the computation later.

        The six arguments before ``*args`` are configuration, not differentiable inputs;
        ``backward`` returns one ``None`` for each of them, so that prefix must stay in sync
        with this signature. ``args`` and ``kwargs`` are passed to ``run_function``.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states so the recompute in backward replays the same random draws.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        if context_fn is not None:
            forward_ctx, recompute_ctx = context_fn()
        else:
            forward_ctx, recompute_ctx = noop_context_fn()

        # Preserve torch autocast context for the backward pass
        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()

        # No autograd graph is recorded here; activations are recomputed in backward.
        with torch.no_grad(), forward_ctx:
            with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
                outputs = run_function(*args, **kwargs)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True),
            )

        # Store everything: tensors via save_for_backward, everything else on ctx.
        ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
        tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
        ctx.save_for_backward(*tensor_inputs)

        fp8 = FP8GlobalStateManager.is_fp8_enabled()
        ctx.get_rng_state_tracker = get_rng_state_tracker
        ctx.tp_group = tp_group
        ctx.recompute_ctx = recompute_ctx
        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
        ctx.fp8 = fp8
        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call backward function with activation recomputation.

        Recomputes the forward pass under the RNG/autocast/FP8 state captured in
        ``forward``, backpropagates ``args`` (the output gradients) through it, and returns
        the input gradients preceded by six ``None`` entries for the configuration
        arguments. The RNG states current at the start of this call are restored
        afterwards, even if the recompute raises.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        inputs = tuple(
            t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
        )

        get_rng_state_tracker = ctx.get_rng_state_tracker

        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape),
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        detached_inputs = detach_variable(inputs)
        try:
            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
            if get_rng_state_tracker is not None:
                get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

            # Compute the forward pass.
            with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
                activation_recompute=True, recompute_phase=True
            ), fp8_autocast(
                enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
            ):
                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
        finally:
            # Set the states back to what it was at the start of this function, also when
            # the recompute failed, so callers that handle the error see unchanged RNG state.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
            if get_rng_state_tracker is not None:
                get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        outputs_with_grad = []
        args_with_grad = []
        for i, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )

        # backward does not require entering autocast context because
        # backward implementations already retrieve fp8 recipe and
        # enablement from stored ctx.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
        )
        # One None per configuration argument of forward (run_function through kwargs).
        return (None, None, None, None, None, None) + grads

429

430
431
432
433
class _CheckpointFrame:
    """
    Storage frame for forward RNG states and detached activations from the forward recompute.
    """
434
435

    def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):
436
437
438
439
440
441
442
443
444
445
446
        self.recompute_fn = recompute_fn
        self.recomputed = []
        self.count = 0
        self.get_rng_state_tracker = get_rng_state_tracker
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Cache fwd/bwd RNG states in the frame to restore later."""
        rng_states = (
            torch.get_rng_state(),
447
            _get_cuda_rng_state(graph_safe=False),
448
449
        )
        if self.get_rng_state_tracker is not None:
450
            rng_states += (self.get_rng_state_tracker().get_states(),)
451
452
453
454
455
456
457
458
459
460
461
462
463
464

        if forward:
            self.fwd_rng_states = rng_states
        else:
            self.bwd_rng_states = rng_states

    def restore_rng_states(self, forward=True):
        """Restore fwd/bwd RNG states that were previously cached into the frame."""
        if forward:
            rng_states = self.fwd_rng_states
        else:
            rng_states = self.bwd_rng_states

        torch.set_rng_state(rng_states[0])
465
        _set_cuda_rng_state(rng_states[1], graph_safe=False)
466
467
468
469
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(rng_states[2])


470
471
472
class _recomputation_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
    """torch.autograd hook for packing/unpacking tensors during the activation recompute phase."""

    def __init__(self, frame):

        def pack_hook(x):
            """
            Packing hook for each recomputed activation passed into the `ctx.save_for_backward()`
            call in the forward recomputation.
            """
            frame.recomputed.append(x.detach())
            return x.detach()

        def unpack_hook(x):
            """
            No-op unpack hook that will never be called because the backward pass for the
            forward recomputation is never triggered.
            """
            return x

        super().__init__(pack_hook, unpack_hook)


495
496
497
class _checkpoint_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    """torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass."""

    def __init__(self, frame, args, kwargs):

        def pack_hook(x):
            """
            Packing hook for each tensor passed into `ctx.save_for_backward()` call in the
            forward pass. Since this is the first forward pass, we discard the tensor and instead
            pack a placeholder tensor index into the autograd engine context.
            """
            del x
            idx = frame.count
            frame.count += 1
            return idx

        def unpack_hook(idx):
            """
            Unpacking hook for each tensor that comes out of the `ctx.saved_tensors` call in the
            backward pass. The first time this is called, the _recomputation_hook will save all the
            activation tensors from `ctx.save_for_backward()` in the forward recomputation into the
            _CheckpointFrame. Subsequent calls will simply return the already recomputed activation
            tensor at the given index of the _CheckpointFrame storage.
            """

            if not frame.recomputed:
                # Store current RNG states in the backward pass
                frame.cache_rng_states(forward=False)

                # Set RNG states to what we saved before the forward pass
                frame.restore_rng_states(forward=True)

                # Recompute the forward pass
                with _recomputation_hook(frame):
                    frame.recompute_fn(*args, **kwargs)

                # Restore RNG states back to the backward pass
                frame.restore_rng_states(forward=False)

            # Return the already recomputed activation tensor at the given index
            activation = frame.recomputed[idx]
            frame.recomputed[idx] = None
            return activation

        super().__init__(pack_hook, unpack_hook)


def use_reentrant_activation_recompute():
    """Report whether the latest `checkpoint()` call selected the reentrant implementation.

    Reads the module-level flag that `checkpoint()` sets from its ``use_reentrant`` keyword.
    """
    reentrant = _USE_REENTRANT_ACTIVATION_RECOMPUTE
    return reentrant


def get_activation_recompute_contexts():
    """Build the pair of TE context managers used by the non-reentrant checkpoint.

    Returns:
        ``(forward_ctx, recompute_ctx)``: two `activation_recompute_forward` instances with
        ``activation_recompute=True``. The first uses ``recompute_phase=False`` (checkpointed
        forward pass) and the second ``recompute_phase=True`` (forward recompute).
    """
    return tuple(
        activation_recompute_forward(activation_recompute=True, recompute_phase=phase)
        for phase in (False, True)
    )


562
def has_te_modules(network):
    """
    Check if there are any Transformer Engine modules in the network.

    For a ``torch.nn.Module``, returns True when any module yielded by
    ``network.modules()`` (which includes ``network`` itself) is an instance of one of the
    TE classes imported below. Any other object, such as a plain callable, cannot be
    inspected and is conservatively reported as True.
    """
    # Local imports, as in the original (presumably to avoid import cycles at module load).
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    if not isinstance(network, torch.nn.Module):
        # Cannot check for TE modules inside a custom class/callable that's not a
        # torch.nn.Module, so just assume that it has TE modules just to be safe.
        return True

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    return any(isinstance(submodule, te_classes) for submodule in network.modules())


@torch._disable_dynamo
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::
        `use_reentrant=False` does not support early stopping, and will execute the entire forward
        pass for the checkpointed module when recomputing activations in the backward pass.

    Parameters
    ----------
    function: Callable
            pytorch module used to run the forward and backward passes using
            the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations: bool, default = False
            if set to `True` and `use_reentrant=True`, first tensor argument is distributed
            across the specified tensor parallel group (`tp_group`) before saving it for the
            backward pass. This has no effect when `use_reentrant=False`.
    get_rng_state_tracker: `Callable`, default = None
            python callable which returns an instance of :func:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
            tensor parallel process group. Used only when `distribute_saved_activations=True`
            and `use_reentrant=True`. If `None`, it falls back to the default group.
    use_reentrant : bool, default = True
            perform checkpointing in reentrant mode.
    context_fn : Callable, default = torch.utils.checkpoint.noop_context_fn
            callable returning a `(forward_ctx, recompute_ctx)` pair of context managers,
            entered around the checkpointed forward pass and the recompute respectively.
    determinism_check : str, default = "default"
            forwarded to :func:`torch.utils.checkpoint.checkpoint` when :attr:`function`
            contains no Transformer Engine modules; ignored otherwise.
    debug : bool, default = False
            forwarded to :func:`torch.utils.checkpoint.checkpoint` when :attr:`function`
            contains no Transformer Engine modules; ignored otherwise.
    args : tuple
            tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
            dictionary of string keys for keyword arguments to :attr:`function`.
    """
    # Pop out te.distributed.checkpoint() arguments
    global _USE_REENTRANT_ACTIVATION_RECOMPUTE
    _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
    distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
    tp_group = kwargs.pop("tp_group", None)
    get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

    # Ensure backward compatibility with the old positional signature
    # checkpoint(function, distribute_saved_activations, get_rng_state_tracker, tp_group, *args).
    # `args[2] is None or isinstance(...)` is equivalent to the Python >= 3.10-only
    # `isinstance(args[2], None | dist_group_type)` but also works on older interpreters.
    if (
        len(args) > 3
        and isinstance(args[0], bool)
        and callable(args[1])
        and (args[2] is None or isinstance(args[2], dist_group_type))
    ):
        warnings.warn(
            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
            DeprecationWarning,
            stacklevel=2,
        )
        distribute_saved_activations = args[0]
        get_rng_state_tracker = args[1]
        tp_group = args[2]
        args = args[3:]

    # Trigger the native PyTorch checkpoint if the function is not or does not contain a
    # Transformer Engine module.
    context_fn = kwargs.pop("context_fn", noop_context_fn)
    determinism_check = kwargs.pop("determinism_check", "default")
    debug = kwargs.pop("debug", False)
    if not has_te_modules(function):
        return torch.utils.checkpoint.checkpoint(
            function,
            *args,
            use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
            context_fn=context_fn,
            determinism_check=determinism_check,
            debug=debug,
            **kwargs,
        )

    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
        setattr(function, "fsdp_wrapped", False)
        setattr(function, "fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
    # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
    #       cannot be sure there are no TE modules inside the function. It also means we might run
    #       the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
    #       user context function.
    del determinism_check, debug
    if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
        # If saved activations need to be distributed but there is no process group,
        # default to the world group.
        if distribute_saved_activations:
            assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
            tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

        return _CheckpointFunction.apply(
            function,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            context_fn,
            kwargs,
            *args,
        )

    if distribute_saved_activations:
        warnings.warn(
            "`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
            "The non-reentrant checkpoint implementation does not manually store forward "
            "inputs for the activation recompute in the backward pass, and instead leverages "
            "the autograd engine's pack/unpack hooks."
        )

    user_forward_ctx, user_recompute_ctx = context_fn()
    te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()

    # Preserve the torch autocast contexts from the forward pass during recompute phase.
    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

    # Capture the FP8 state now so the recompute in backward runs under the same recipe.
    fp8 = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

    def recompute_fn(*args, **kwargs):
        with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast(
            enabled=fp8, fp8_recipe=fp8_recipe
        ):
            function(*args, **kwargs)

    # Initialize a new checkpoint frame for each new forward pass.
    new_frame = _CheckpointFrame(
        recompute_fn,
        get_rng_state_tracker,
    )
    new_frame.cache_rng_states(forward=True)

    with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
        out = function(*args, **kwargs)

    return out
746

747

748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
class CudaRNGStatesTracker:
    """
    Tracker for named CUDA RNG states used with model parallelism.

    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the `add` method, a
    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracker).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Return a shallow copy of the name -> RNG state mapping.

        The dict itself is new, so later `add`/`set_states` calls on the tracker do
        not change it; the state objects it holds are shared, not cloned.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        states: Dict[str, torch.Tensor]
               A mapping from string names to RNG states.
        """
        self.states_ = states

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        name: str
             string identifier for the RNG state.
        seed: int
             PyTorch seed for the RNG state.

        Raises RuntimeError if `seed` was already used or `name` is already
        registered. Both checks happen before any bookkeeping is modified, so a
        rejected call leaves the tracker unchanged.
        """
        # Validate both conditions before mutating: recording the seed first would
        # permanently "consume" it even when the name check then fails.
        if seed in self.seeds_:
            raise RuntimeError(f"seed {seed} already exists")
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")
        self.seeds_.add(seed)

        if graph_safe_rng_available():
            # Clone the current generator state and reseed the clone.
            new_state = _get_cuda_rng_state(clone=True)
            new_state.manual_seed(seed)
            self.states_[name] = new_state
        else:
            # Temporarily reseed the global CUDA RNG, snapshot it, and always
            # restore the previous global state afterwards.
            orig_rng_state = _get_cuda_rng_state()
            try:
                torch.cuda.manual_seed(seed)
                self.states_[name] = _get_cuda_rng_state(clone=True)
            finally:
                _set_cuda_rng_state(orig_rng_state)

        # Update global states.
        set_all_rng_states(self.states_)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        name: str
             string identifier for the RNG state.

        Raises KeyError if `name` has not been registered with `add`.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")
        # Get the reference to current rng state.
        orig_cuda_rng_state = _get_cuda_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Persist the advanced state so the next fork continues the stream;
            # this is redundant with graph-safe API
            if not graph_safe_rng_available():
                self.states_[name] = _get_cuda_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


Przemek Tredak's avatar
Przemek Tredak committed
853
def reduce_scatter_along_first_dim(
    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """Reduce-scatter the input tensor across model parallel group.

    The first dimension of `inp` is split evenly over the ranks of `tp_group`, and
    each rank receives the element-wise sum of its shard.

    Returns
    -------
    (output, handle)
        The local shard and the work handle returned by
        ``torch.distributed.reduce_scatter_tensor``. With a single-rank group the
        input is returned unchanged together with ``None``.
    """
    tp_size = get_distributed_world_size(tp_group)
    if tp_size == 1:
        # Nothing to scatter with only one GPU.
        return inp, None

    shard_shape = list(inp.size())
    assert (
        shard_shape[0] % tp_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    shard_shape[0] //= tp_size

    shard = torch.empty(shard_shape, dtype=inp.dtype, device=torch.cuda.current_device())
    work = torch.distributed.reduce_scatter_tensor(
        shard,
        inp.contiguous(),
        group=tp_group,
        async_op=async_op,
    )
    return shard, work


876
def _all_gather_fp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
    """All-gather FP8 tensor along first dimension.

    If `inp` is not already FP8 it is cast with `quantizer`, producing row-wise
    data only, because a transposed FP8 tensor cannot be all-gathered directly.
    Scale-inverses are copied from the local input; this assumes they are
    identical across ranks. If the quantizer requests column-wise usage and
    non-TN FP8 GEMMs are not supported, the collective is waited on and the
    transpose is built locally, in which case the returned handle is ``None``.

    Raises
    ------
    ValueError
        If `quantizer` is not an FP8 quantizer, or `inp` is not FP8 and no
        quantizer is given.
    RuntimeError
        If `inp` is an FP8 base tensor (not `Float8Tensor`) and no quantizer is given.
    """
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid
    if quantizer is not None and not isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Cast input tensor to FP8 if needed
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorBase):
        if quantizer is None:
            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
        try:
            inp = quantizer(inp)
        finally:
            # Restore the caller's usage flags even if quantization fails.
            quantizer.set_usage(
                rowwise=init_rowwise_usage,
                columnwise=init_columnwise_usage,
            )

    # Construct output tensor
    out: Float8TensorBase
    if quantizer is not None:
        dtype = torch.float32
        device = "cuda"
        if isinstance(inp, Float8Tensor):
            dtype = inp.dtype
            device = inp.device
        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
    elif isinstance(inp, Float8Tensor):
        out = inp.make_like(inp, shape=out_shape)
        # `out_shape` is a list of ints, so allocate with torch.empty
        # (torch.empty_like requires a tensor argument).
        out._data = torch.empty(
            out_shape,
            dtype=torch.uint8,
            device=inp.device,
        )
        out._transpose = None
        out._transpose_invalid = True
    else:
        raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")

    # Assume scaling factors are identical across ranks
    out._scale_inv = inp._scale_inv

    # Perform communication
    handle = torch.distributed.all_gather_into_tensor(
        out._data,
        inp._data.contiguous(),
        group=process_group,
        async_op=async_op,
    )

    # Make sure FP8 transpose is populated if needed
    needs_transpose = (
        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        # The transpose must be built from the gathered data, so wait first.
        if handle is not None:
            handle.wait()
            handle = None
        out._create_transpose()

    return out, handle


958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
def _all_gather_fp8_blockwise(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,  # pylint: disable=unused-argument
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather FP8 tensor along first dimension for blockwise quantization.

    Returns: quantizer(gather(inp))

    NOTE: The implementation is not sophisticated enough to honor async_op=True.
    In some cases it falls back to synchronous gather and invokes the quantizer.
    Only the high-precision-input path is implemented: a high-precision `inp` is
    gathered in its own dtype and then quantized. Gathering an already-quantized
    blockwise tensor raises NotImplementedError.

    Raises
    ------
    ValueError
        If `inp` has an unsupported type, carries no data, or `quantizer` is not
        an FP8 blockwise quantizer.
    NotImplementedError
        For unsupported quantizer settings or already-quantized inputs.
    """

    # Input tensor attributes
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, Float8BlockwiseQTensorBase):
        if inp._rowwise_data is not None:
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
            f"found {inp.__class__.__name__})"
        )
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid. Attributes are only read when a quantizer
    # was given; a None quantizer used to crash here with AttributeError.
    if quantizer is not None:
        if not isinstance(quantizer, Float8BlockQuantizer):
            raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
        if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
            raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")

    # Doing BF16 gather for now as baseline because it's simpler
    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
        # Output tensor dims (only needed on this path)
        if out_shape is None:
            out_shape = list(inp.size())
            out_shape[0] *= world_size
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(
            out, inp.contiguous(), group=process_group, async_op=False
        )
        out = quantizer(out)
        return out, None
    # Implementation of fp8 gather needs to account for:
    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
    # * Gathering non GEMM swizzled scales.
    # * Refer to scaffold code when implementing at:
    # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
    raise NotImplementedError("fp8 blockwise allgather not yet implemented")


1026
def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: MXFP8Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
    """All-gather MXFP8 tensor along first dimension.

    If `inp` cannot be quantized by `quantizer` (`is_quantizable` is False), it
    is gathered in high precision and quantized afterwards, and the handle is
    ``None``. Otherwise `inp` is cast to MXFP8 if needed (or dequantized and
    requantized when its usages do not match the quantizer). The data and
    scale-inverses for each requested usage are then gathered inside one NCCL
    coalescing block. When `async_op` is True, the returned handle is the
    coalescing manager.

    Raises
    ------
    ValueError
        If `inp` has an unsupported type or is an MXFP8 tensor without data.
    """

    # Input tensor attributes
    in_shape: torch.Size
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        in_shape = inp.size()
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, MXFP8TensorBase):
        # Take shape/device from whichever data buffer is populated.
        # (Previously `.device.size()` was called, which does not exist on torch.device.)
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            in_shape = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got MXFP8 input tensor without any data")
        # The FP8 buffers do not record a high-precision dtype; guess BF16.
        dtype = torch.bfloat16
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
            f"found {inp.__class__.__name__})"
        )

    # Output tensor shape
    world_size = get_distributed_world_size(process_group)
    if out_shape is None:
        # Unpack explicitly: `in_shape` is a torch.Size (tuple), not a list.
        out_shape = [in_shape[0] * world_size, *in_shape[1:]]

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to FP8.
    if (
        not isinstance(inp, MXFP8TensorBase)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to MXFP8 with required data
    if not isinstance(inp, MXFP8TensorBase):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to MXFP8."
        )
        inp = quantizer(inp.dequantize())

    # Construct MXFP8 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Coalesce NCCL collectives
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as coalescing_manager:

        # Gather MXFP8 data for row-wise usage
        if quantizer.rowwise_usage:

            # Remove padding from MXFP8 scale-inverses
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv[flattened_in_shape0 * world_size :].zero_()
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )

        # Gather MXFP8 data for column-wise usage
        if quantizer.columnwise_usage:

            # Remove padding from MXFP8 scale-inverses
            # (column-wise scales have one row per 32 rows of data)
            in_scale_inv = inp._columnwise_scale_inv
            out_scale_inv = out._columnwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv[flattened_in_shape0 * world_size :].zero_()
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )

    handle = coalescing_manager if async_op else None
    return out, handle
1156
1157


Przemek Tredak's avatar
Przemek Tredak committed
1158
def gather_along_first_dim(
    inp: torch.Tensor,
    process_group: dist_group_type,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather tensors and concatenate along first dimension.

    Dispatch by input/quantizer type:

    * FP8 (delayed or current scaling) -> `_all_gather_fp8`
    * FP8 block scaling -> `_all_gather_fp8_blockwise`
    * MXFP8 -> `_all_gather_mxfp8`
    * DebugQuantizedTensor -> recursive gather of its row-/column-wise tensors
    * any other quantizer -> high-precision gather, then quantize
    * plain tensors -> a single `all_gather_into_tensor`

    Parameters
    ----------
    inp: torch.Tensor
        Local shard (plain or quantized tensor).
    process_group: dist_group_type
        Group to gather across.
    async_op: bool, default = False
        Request an asynchronous collective. Used by the plain-tensor path and
        forwarded to the FP8/MXFP8 helpers. The debug and high-precision
        fallback paths always run synchronously and return ``None`` as handle.
    quantizer: Optional[Quantizer]
        If given, the gathered result is produced in this quantizer's format.

    Returns
    -------
    (output, handle)
        Gathered tensor and optional work handle.
    """

    # Return immediately if no communication is required
    world_size = get_distributed_world_size(process_group)
    if world_size == 1:
        if quantizer is not None and not isinstance(inp, QuantizedTensor):
            inp = quantizer(inp)
        return inp, None

    # Output tensor dims
    out_shape = list(inp.size())
    out_shape[0] *= world_size

    # FP8 case: delayed scaling or current scaling
    if isinstance(inp, Float8TensorBase) or isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        return _all_gather_fp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # FP8 block scaling case, block length = 128
    if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
        return _all_gather_fp8_blockwise(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # MXFP8 case
    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
        assert isinstance(quantizer, MXFP8Quantizer)
        return _all_gather_mxfp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # Debug case - call gather_along_first_dim on each tensor
    if isinstance(inp, DebugQuantizedTensor):
        out_obj = inp
        rowwise = inp.get_tensor(False)
        columnwise = inp.get_tensor(True)
        final_quantizer = (
            None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
        )
        rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
        out_obj.rowwise_gemm_tensor = rowwise_total
        if rowwise is not columnwise:
            final_quantizer_columnwise = (
                None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
            )
            columnwise_total, _ = gather_along_first_dim(
                columnwise, process_group, False, final_quantizer_columnwise
            )
            out_obj.columnwise_gemm_tensor = columnwise_total
        else:
            # Same object serves both GEMMs: reuse the single gathered result instead
            # of leaving the column-wise tensor un-gathered (previously a no-op
            # self-assignment of rowwise_gemm_tensor).
            out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor
        return out_obj, None

    # High-precision communication for quantized tensors
    if quantizer is not None:
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        if isinstance(inp, QuantizedTensor):
            inp = inp.dequantize()
        out = torch.empty(
            out_shape,
            dtype=inp.dtype,
            device=inp.device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
        out = quantizer(out)
        return out, None

    # Dequantize quantized tensor if not supported
    if isinstance(inp, QuantizedTensor):
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        inp = inp.dequantize()

    # Communication for plain PyTorch tensors
    out = torch.empty(
        out_shape,
        dtype=inp.dtype,
        device=inp.device,
        memory_format=torch.contiguous_format,
    )
    handle = torch.distributed.all_gather_into_tensor(
        out,
        inp.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return out, handle
Przemek Tredak's avatar
Przemek Tredak committed
1274
1275


1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
# Global cache to store symmetric memory tensors created by `get_symmetric_memory_tensor`.
# Key: (numel, dtype, device, process-group name, tag). Value: a 1D symmetric memory
# tensor that has already been rendezvoused with that group, so it can be reused
# without another rendezvous.
symmetric_mem_cache: Dict[tuple, torch.Tensor] = {}


def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None):
    """
    Return a 1D symmetric memory tensor, creating and caching it on first request.

    Tensors are cached in `symmetric_mem_cache`, keyed by size, dtype, device,
    process-group name and `tag`. A cache hit therefore skips both the allocation
    and the rendezvous.

    Parameters
    ----------
    tensor_numel : int
        Number of elements in the tensor.
    tensor_dtype : torch.dtype
        Data type of the tensor.
    tensor_device : torch.device
        Device on which to allocate the tensor.
    tp_group : dist_group_type
        Process group for rendezvous operation.
    tag : Any, optional
        Optional identifier to further distinguish tensors.

    Returns
    -------
    torch.Tensor
        A symmetric memory tensor with the specified properties.
    """
    key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag)

    # Fast path: a matching tensor has already been allocated and rendezvoused.
    if key in symmetric_mem_cache:
        return symmetric_mem_cache[key]

    # Slow path: allocate, rendezvous once with the group, then remember it.
    buffer = symm_mem.empty(tensor_numel, dtype=tensor_dtype, device=tensor_device)
    symm_mem.rendezvous(buffer, group=tp_group)
    symmetric_mem_cache[key] = buffer
    return buffer


def symmetric_all_reduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
    all_reduce_type: str = "multimem_all_reduce",
):
    """
    Performs an all-reduce operation across multiple processes using symmetric memory.
    If the input tensor is already in the symmetric memory cache we can avoid copy
    overheads by just directly using the input tensor for all reduce.  Externally
    created symmetric memory tensors not in the cache currently will not be able to
    avoid the extra copies.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor to be reduced. The operation is performed in-place.

    tp_group : Optional[dist_group_type], default=None
        The process group over which to perform the all-reduce operation.
        If None, the default process group is used.

    async_op : bool, default=False
        Whether to perform the operation asynchronously.
        Note: Currently only synchronous operations are supported for symmetric memory variants.

    all_reduce_type : str, default="multimem_all_reduce"
        The type of all-reduce implementation to use. Options include:
        - "nccl": Standard PyTorch distributed all-reduce
        - "multimem_all_reduce": multimem symmetric all-reduce (in-place op)
        - "two_shot": Two-shot symmetric all-reduce (in-place op)
        - "one_shot": One-shot symmetric all-reduce (out-of-place op; its result
          is copied back into `inp`)

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.distributed.Work]]
        - The first element is the input tensor with the all-reduce result.
        - The second element is the async work handle if async_op=True,
          otherwise None.

    Raises
    ------
    TypeError
        If `all_reduce_type` is not one of the options above.
    """
    assert async_op is False, "Async symmetric ops not supported yet"
    assert HAS_TORCH_SYMMETRIC, "Could not import symmetric memory from torch"

    if get_distributed_world_size(tp_group) == 1:
        return inp, None

    if all_reduce_type == "nccl":
        # Standard all-reduce implementation
        handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
        return inp, handle

    # Select the op and whether it writes its result into its argument.
    if all_reduce_type == "multimem_all_reduce":
        all_reduce_impl, in_place = torch.ops.symm_mem.multimem_all_reduce_, True
    elif all_reduce_type == "two_shot":
        all_reduce_impl, in_place = torch.ops.symm_mem.two_shot_all_reduce_, True
    elif all_reduce_type == "one_shot":
        # No trailing underscore: returns a new tensor, input is left untouched.
        all_reduce_impl, in_place = torch.ops.symm_mem.one_shot_all_reduce, False
    else:
        raise TypeError(f"All reduce type {all_reduce_type} is not supported.")

    # Symmetric-memory ops need a concrete group name, so resolve None to the
    # default group as documented (previously `None.group_name` raised).
    group = tp_group if tp_group is not None else torch.distributed.group.WORLD
    group_name = group.group_name
    tensor_shape = inp.shape

    def _reduce(buffer: torch.Tensor) -> torch.Tensor:
        """Run the selected op on `buffer` and return the tensor holding the result."""
        result = all_reduce_impl(buffer, "sum", group_name)
        return buffer if in_place else result

    input_id = id(inp)
    is_cached = any(id(cached_tensor) == input_id for cached_tensor in symmetric_mem_cache.values())
    # Check if the input tensor is already in the symmetric memory cache. If it is we can avoid copy overheads.
    if is_cached:
        reduced = _reduce(inp)
        if reduced is not inp:
            inp.copy_(reduced)
    else:
        # Get symmetric memory tensor. Build or retrieve from cache.
        msg = get_symmetric_memory_tensor(inp.numel(), inp.dtype, inp.device, group)
        msg.copy_(inp.reshape(-1))
        reduced = _reduce(msg)
        # Copy the result back to the input tensor
        inp.copy_(reduced.reshape(tensor_shape))

    return inp, None


Przemek Tredak's avatar
Przemek Tredak committed
1422
def allreduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """All-reduce the input tensor across model parallel group.

    The sum is written into `inp` in place, and `inp` is returned together with
    the work handle from ``torch.distributed.all_reduce``. If the group has a
    single rank, no communication is issued and the handle is ``None``.
    """
    if get_distributed_world_size(tp_group) != 1:
        work = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
    else:
        # Only one GPU: nothing to reduce.
        work = None
    return inp, work
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446


def _fsdp_scatter_tensors(
    fsdp_group: dist_group_type,
    *tensors: torch.Tensor,
):
    shapes = []
    if fsdp_group is not None:
        for t in tensors:
            if isinstance(t, torch.Tensor):
1447
1448
1449
1450
1451
1452
1453
                targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
                for target in targets:
                    shapes.append(target.data.shape)
                    safely_set_viewless_tensor_data(
                        target,
                        split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True),
                    )
1454
1455
1456
1457
1458
1459
1460
            else:
                shapes.append(None)
    return shapes


def _fsdp_gather_tensors(
    fsdp_group: dist_group_type,
1461
    shapes: List[Tuple[int, ...]],
1462
1463
1464
1465
1466
1467
1468
    *tensors: torch.Tensor,
):
    if fsdp_group is not None:
        assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal."
        for s, t in zip(shapes, tensors):
            if isinstance(t, torch.Tensor):
                assert s is not None, "Internal TE error."
1469
1470
1471
1472
1473
                targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
                for target in targets:
                    safely_set_viewless_tensor_data(
                        target, gather_split_1d_tensor(target.data, fsdp_group).view(s)
                    )
1474
1475
1476
1477
1478
1479
1480
1481
1482


def _is_te_module(module):
    """
    Return True if `module` is an instance of a Transformer Engine class that needs
    the TE checkpoint implementation for activation recompute.

    The recognized classes are LayerNorm, RMSNorm, TransformerEngineBaseModule,
    UnfusedDotProductAttention, DotProductAttention, MultiheadAttention and
    TransformerLayer. Subclasses also match, since `isinstance` is used.
    """
    # Imported inside the function rather than at module top level; presumably to
    # avoid circular imports between these packages and this module (TODO confirm).
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    return isinstance(module, te_classes)


def _assert_no_fp8_primary_weights(module: torch.nn.Module) -> None:
    """Assert that `module` was not created with FP8 primary weights (absent attribute is OK)."""
    # A module without `primary_weights_in_fp8` is treated as having non-FP8 primary weights,
    # as in the original `hasattr` + `assert` pair.
    assert not getattr(module, "primary_weights_in_fp8", False), (
        "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
        "Please initialize your model without the te.fp8_model_init(...) context."
    )


def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    """
    Inject FSDP process group references into FSDP-wrapped TE modules in an FSDP-wrapped root
    module in order to scatter/gather the FP8 weight copies at the same time FSDP scatters/gathers
    its `FlatParameters`.

    Parameters
    ----------
    fsdp_root: torch.nn.Module
               FSDP-wrapped root module that may contain FSDP-wrapped TE modules.

    Raises
    ------
    AssertionError
        If `fsdp_root` is not FSDP-wrapped, if the root TE module has no valid `_FSDPState`,
        or if a TE module that is found was created with FP8 primary weights.
    """
    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."

    # If the root module is a TE module, inject FSDP information into it.
    root_module = fsdp_root.module
    if _is_te_module(root_module):
        # Inspect the wrapped module, exactly as the submodule loop below does.
        _assert_no_fp8_primary_weights(root_module)
        root_state = _get_module_fsdp_state(fsdp_root)
        assert root_state is not None, "Root module does not have a valid _FSDPState."
        root_module.fsdp_group = root_state.process_group

    # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules.
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
    for state, fsdp_module in zip(fsdp_states, fsdp_modules):
        wrapped = fsdp_module.module
        if _is_te_module(wrapped):
            _assert_no_fp8_primary_weights(wrapped)
            wrapped.fsdp_group = state.process_group


class FullyShardedDataParallel(FSDP):
    """
    Drop-in replacement for `torch.distributed.fsdp.FullyShardedDataParallel` for models that
    contain Transformer Engine modules.

    The regular FSDP wrapping runs first. The wrapper then passes the FSDP process group(s) it
    finds to the TE modules inside it (via `prepare_te_modules_for_fsdp`). This lets TE modules
    scatter their activation tensors after each forward pass and gather them before the
    backward pass.
    """

    def __init__(self, module, *args, **kwargs):
        # Build the standard FSDP wrapper first: the TE preparation step reads its FSDP state.
        super().__init__(module, *args, **kwargs)
        # Then hand the FSDP process group(s) to every wrapped TE module.
        prepare_te_modules_for_fsdp(fsdp_root=self)