distributed.py 73.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""
6
7
from __future__ import annotations

8
from collections.abc import Iterable
9
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
10
from functools import lru_cache
11
from dataclasses import dataclass
12
import math
13
14
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
15

Przemek Tredak's avatar
Przemek Tredak committed
16
import torch
17
from torch.cuda import _lazy_call, _lazy_init
18
from torch.utils.checkpoint import detach_variable, noop_context_fn
19
20
21
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
Przemek Tredak's avatar
Przemek Tredak committed
22

23
24
25
26
27
28
29
30
31
try:
    import torch.distributed._symmetric_memory as symm_mem

    HAS_TORCH_SYMMETRIC = True
except ImportError:
    HAS_TORCH_SYMMETRIC = False

import transformer_engine_torch as tex

32
from . import torch_version
33
34
35
36
37
from .utils import (
    is_non_tn_fp8_gemm_supported,
    safely_set_viewless_tensor_data,
    needs_quantized_gemm,
)
Przemek Tredak's avatar
Przemek Tredak committed
38
from .constants import dist_group_type
39
from .quantization import FP8GlobalStateManager, autocast
40
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
41
from .tensor.mxfp8_tensor import MXFP8Quantizer
42
from .tensor.nvfp4_tensor import NVFP4Quantizer
43
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
44
45
46
47
48
from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
49
from .triton.pad import pad_columnwise_scale_inv
50
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
51
52


53
54
55
__all__ = ["checkpoint", "CudaRNGStatesTracker"]


Przemek Tredak's avatar
Przemek Tredak committed
56
57
58
59
60
61
# Tensor-parallel metadata attribute names (with default values). The keys are the
# attributes that `set_tensor_model_parallel_attributes` refuses to overwrite.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

# Whether activation recompute uses the reentrant implementation. `checkpoint`
# overwrites this on every call from its `use_reentrant` keyword (default True).
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True

# Flags published by `activation_recompute_forward` while it is active and reset to
# False when it exits; read via `is_fp8_activation_recompute_enabled` and
# `in_fp8_activation_recompute_phase`.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = _FP8_ACTIVATION_RECOMPUTE_PHASE = False

Przemek Tredak's avatar
Przemek Tredak committed
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Registry of generator states used by `CudaRNGStatesTracker`, exposed through
# `get_all_rng_states` / `set_all_rng_states`.
_ALL_ACTIVE_RNG_STATES = {}


def get_all_rng_states() -> Dict:
    """Return all generator states used by `CudaRNGStatesTracker`.

    The registry object itself is returned (not a copy), so changes made through the
    returned object are visible to every other user of this module.
    """
    return _ALL_ACTIVE_RNG_STATES


def set_all_rng_states(states: Dict) -> None:
    """Replace the registry of generator states used by `CudaRNGStatesTracker`.

    ``states`` is stored by reference (not copied) and is the object returned by
    subsequent `get_all_rng_states` calls.
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states


def graph_safe_rng_available() -> bool:
    """Report whether this PyTorch build supports CUDA-graph-safe RNG state handling.

    True only when ``torch.cuda.CUDAGraph`` has ``register_generator_state`` and
    ``torch.Generator`` has ``graphsafe_set_state``, ``graphsafe_get_state`` and
    ``clone_state``.
    """
    generator_methods = ("graphsafe_set_state", "graphsafe_get_state", "clone_state")
    return hasattr(torch.cuda.CUDAGraph, "register_generator_state") and all(
        hasattr(torch.Generator, method) for method in generator_methods
    )
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119


def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> Union[torch.Tensor, torch.Generator]:
    """Return the random number generator state of the specified GPU.

    Parameters
    ----------
    device : int, str or torch.device, default = "cuda"
        An int is treated as a CUDA device ordinal, and a str is parsed with
        ``torch.device``. A device without an index resolves to
        ``torch.cuda.current_device()``.
    clone : bool, default = False
        Only used on the graph-safe path. If True, return ``clone_state()`` (a copy)
        instead of ``graphsafe_get_state()`` (a reference to the live state).
    graph_safe : bool, default = True
        Use the graph-safe generator API when `graph_safe_rng_available()` is True.

    Returns
    -------
    torch.Generator on the graph-safe path, otherwise the tensor from ``get_state()``.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    idx = device.index
    if idx is None:
        idx = torch.cuda.current_device()
    generator = torch.cuda.default_generators[idx]

    if not (graph_safe_rng_available() and graph_safe):
        return generator.get_state()
    # clone_state() copies the state; graphsafe_get_state() references the live state.
    return generator.clone_state() if clone else generator.graphsafe_get_state()


def _set_cuda_rng_state(
    new_state: Union[torch.Tensor, torch.Generator],
    device: Union[int, str, torch.device] = -1,
    graph_safe: bool = True,
) -> None:
    """Sets the random number generator state of the current GPU.

    Parameters
    ----------
    new_state : torch.Tensor or torch.Generator
        On the graph-safe path this is passed to ``graphsafe_set_state``; otherwise it
        is passed to ``set_state``.
    device : int, str or torch.device, default = -1
        ``-1`` means the current CUDA device. Any other int is a CUDA ordinal, and a
        str is parsed with ``torch.device``. An index-less device resolves to
        ``torch.cuda.current_device()`` when the callback runs.
    graph_safe : bool, default = True
        Use ``graphsafe_set_state`` when `graph_safe_rng_available()` is True.

    The update is wrapped in ``torch.cuda._lazy_call``, which may defer it until CUDA
    is initialized.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        if graph_safe_rng_available() and graph_safe:
            default_generator.graphsafe_set_state(new_state)
        else:
            default_generator.set_state(new_state)

    _lazy_call(cb)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Attach tensor-parallel metadata to ``tensor``.

    Asserts that none of the attribute names in `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS`
    already exists on ``tensor``. It then sets ``tensor_model_parallel``,
    ``partition_dim`` and ``partition_stride`` to ``is_parallel``, ``dim`` and
    ``stride``.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    tp_attributes = (
        ("tensor_model_parallel", is_parallel),
        ("partition_dim", dim),
        ("partition_stride", stride),
    )
    for name, value in tp_attributes:
        setattr(tensor, name, value)


156
@lru_cache
def _cached_distributed_world_size(group: Optional[dist_group_type]) -> int:
    """Query and memoize the world size of ``group`` (torch.distributed must be initialized)."""
    return torch.distributed.get_world_size(group=group)


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return world size for the distributed group.

    Returns 1 when ``torch.distributed`` is not initialized. That fallback is
    deliberately not memoized. Caching it (as a plain ``@lru_cache`` on this function
    would) makes every later call for the same ``group`` return 1, even after
    ``init_process_group``. Only world sizes of an initialized backend are cached,
    keyed by ``group``.
    """
    if not torch.distributed.is_initialized():
        return 1
    return _cached_distributed_world_size(group)


# Keep the ``lru_cache`` wrapper API (``cache_clear`` / ``cache_info``) that callers of
# the previously decorated function could rely on.
get_distributed_world_size.cache_clear = _cached_distributed_world_size.cache_clear
get_distributed_world_size.cache_info = _cached_distributed_world_size.cache_info


164
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return this process's rank within ``group`` (the default group when ``None``).

    Raises AssertionError if torch.distributed has not been initialized. Successful
    results are memoized per ``group`` by ``lru_cache``; a raised assertion is not cached.
    """
    assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
    rank = torch.distributed.get_rank(group=group)
    return rank


def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Optional[Callable],
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize affine weight for model parallel on GPU.

    Parameters
    ----------
    weight : torch.Tensor
        Tensor to initialize. ``init_method(weight)`` presumably fills it in place;
        its return value is ignored.
    init_method : Callable
        Called exactly once as ``init_method(weight)``.
    get_rng_state_tracker : Callable or None
        If given, ``init_method`` runs inside ``get_rng_state_tracker().fork()``. If
        ``None``, it runs under the current RNG state.
    partition_dim : int, default = 0
        Stored on ``weight`` as ``partition_dim`` when ``set_tp_attributes`` is True.
    stride : int, default = 1
        Stored on ``weight`` as ``partition_stride`` when ``set_tp_attributes`` is True.
    set_tp_attributes : bool, default = True
        Whether to tag ``weight`` with ``tensor_model_parallel=True`` and the partition
        metadata via `set_tensor_model_parallel_attributes`.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    if get_rng_state_tracker is None:
        init_method(weight)
        return

    # Draw the initial values from the tracker's forked RNG stream.
    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's contiguous slice of ``tensor`` flattened to 1D.

    The flattened tensor is cut into ``world_size(tp_group)`` chunks of
    ``numel // world_size`` elements each. Any remainder elements are not assigned to
    any rank. With ``new_buffer=False`` the result is a view into ``tensor``
    (``tensor.view(-1)`` must be valid). With ``new_buffer=True`` the slice is copied
    into a freshly allocated buffer on ``torch.cuda.current_device()``.
    """
    chunk_numel = torch.numel(tensor) // get_distributed_world_size(tp_group)
    begin = chunk_numel * get_distributed_rank(tp_group)
    chunk = tensor.view(-1)[begin : begin + chunk_numel]
    if not new_buffer:
        return chunk

    buffer = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    buffer.copy_(chunk)
    return buffer


214
def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor:
    """Inverse of `split_tensor_into_1d_equal_chunks`: all-gather the per-rank chunks.

    Returns a new 1D tensor of ``numel(tensor) * world_size(tp_group)`` elements on
    ``torch.cuda.current_device()``, filled by ``all_gather_into_tensor`` over ``tp_group``.
    """
    local_numel = torch.numel(tensor)
    world_size = get_distributed_world_size(tp_group)
    output = torch.empty(
        local_numel * world_size,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(output, tensor, group=tp_group)
    return output


227
class activation_recompute_forward(AbstractContextManager, ContextDecorator):
    """Context manager used to control the forward runtime behavior when executed
    under activation checkpointing.

    It wraps both the initial (checkpointed) forward pass and its recomputation in
    backward. On entry it publishes two module-level flags:

    * the "enabled" flag, truthy only when ``activation_recompute`` is set and FP8 is
      enabled (`is_fp8_activation_recompute_enabled`);
    * the "phase" flag, equal to ``recompute_phase`` (`in_fp8_activation_recompute_phase`).

    Both flags are reset to False on exit. When ``activation_recompute`` is set, it also
    carries ``FP8GlobalStateManager.IS_FIRST_FP8_MODULE`` from forward to recompute. The
    forward phase appends the current value to a class-wide FIFO, and the recompute
    phase pops the oldest entry and restores it.
    """

    # Class-wide FIFO shared by every instance: IS_FIRST_FP8_MODULE values captured in
    # forward, consumed oldest-first by the recompute phase.
    _is_first_fp8_module: List = []

    def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
        super().__init__()
        self.activation_recompute = activation_recompute
        self.recompute_phase = recompute_phase

    def __enter__(self):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        # FP8 is only queried when activation recompute was requested.
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase

        if not self.activation_recompute:
            return
        pending = activation_recompute_forward._is_first_fp8_module
        if self.recompute_phase:
            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = pending.pop(0)
        else:
            pending.append(FP8GlobalStateManager.IS_FIRST_FP8_MODULE)

    def __exit__(self, *exc_details):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        # Flags are reset to False, not restored to their values before __enter__.
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
        _FP8_ACTIVATION_RECOMPUTE_PHASE = False


def is_fp8_activation_recompute_enabled() -> bool:
    """Return the module-level "FP8 activation recompute enabled" flag.

    `activation_recompute_forward` sets it on entry (truthy only when activation
    recompute was requested and FP8 is enabled) and clears it on exit.
    """
    enabled = _FP8_ACTIVATION_RECOMPUTE_ENABLED
    return enabled


def in_fp8_activation_recompute_phase() -> bool:
    """Return the module-level "in recompute phase" flag.

    `activation_recompute_forward` sets it to its ``recompute_phase`` argument on entry
    and clears it on exit.
    """
    phase = _FP8_ACTIVATION_RECOMPUTE_PHASE
    return phase


275
276
277
278
279
280
281
def _get_active_autocast_contexts():
    """
    Build fresh torch autocast context managers that reproduce the autocast state active
    right now, so it can be re-entered later (e.g. during activation recompute).

    Returns a ``(gpu_autocast_ctx, cpu_autocast_ctx)`` pair. Each context copies the
    current enabled flag and dtype for its device, plus the global autocast cache setting.
    """
    cache_enabled = torch.is_autocast_cache_enabled()

    if torch_version() < (2, 4, 0):
        # Older PyTorch only offers the device-specific autocast APIs.
        gpu_ctx = torch.cuda.amp.autocast(
            torch.is_autocast_enabled(), torch.get_autocast_gpu_dtype(), cache_enabled
        )
        cpu_ctx = torch.cpu.amp.autocast(
            torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype(), cache_enabled
        )
        return gpu_ctx, cpu_ctx

    def _snapshot(device_type):
        # Capture the current autocast settings for ``device_type``.
        return torch.amp.autocast(
            device_type,
            enabled=torch.is_autocast_enabled(device_type),
            dtype=torch.get_autocast_dtype(device_type),
            cache_enabled=cache_enabled,
        )

    gpu_ctx = _snapshot("cuda")
    cpu_ctx = _snapshot("cpu")
    return gpu_ctx, cpu_ctx


316
class _CheckpointFunction(torch.autograd.Function):
    """Reentrant activation checkpointing, adapted from torch.utils.checkpoint with
    two main changes:
        1) CUDA RNG state is captured/restored with `_get_cuda_rng_state` /
           `_set_cuda_rng_state` (non-graph-safe mode) instead of
           ``torch.cuda.get_rng_state`` / ``torch.cuda.set_rng_state``.
        2) the states held by the model-parallel RNG tracker (when one is given) are
           also captured in forward and restored around the recompute.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_rng_state_tracker: Union[Callable, None],
        tp_group: Union[dist_group_type, None],
        context_fn: Union[Callable, None],
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Run ``run_function(*args, **kwargs)`` without building an autograd graph, and
        stash everything needed to recompute it in backward.

        Stored on ``ctx``: the RNG states (CPU, CUDA, and the tracker's when given), the
        user recompute context from ``context_fn``, the active torch autocast contexts,
        the FP8 enablement and recipe, and the inputs. Non-tensor inputs go on
        ``ctx.inputs`` and tensor inputs go through ``save_for_backward``. When
        ``distribute_saved_activations`` is set, only this rank's 1D chunk of
        ``args[0]`` is kept.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Snapshot RNG states so the recompute replays identical randomness.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        forward_ctx, recompute_ctx = noop_context_fn() if context_fn is None else context_fn()

        # Capture the torch autocast state so backward can re-enter it.
        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()

        with torch.no_grad(), forward_ctx:
            with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
                outputs = run_function(*args, **kwargs)

        # Keep only this rank's chunk of the first input; it is re-gathered in backward.
        if distribute_saved_activations:
            first_input = args[0]
            ctx.input_0_shape = first_input.data.shape
            safely_set_viewless_tensor_data(
                first_input,
                split_tensor_into_1d_equal_chunks(first_input.data, tp_group, new_buffer=True),
            )

        # Tensors go through save_for_backward; every other argument is kept on ctx.
        ctx.inputs = []
        tensor_inputs = []
        for arg in args:
            arg_is_tensor = torch.is_tensor(arg)
            ctx.inputs.append(None if arg_is_tensor else arg)
            tensor_inputs.append(arg if arg_is_tensor else None)
        ctx.save_for_backward(*tensor_inputs)

        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
        ctx.get_rng_state_tracker = get_rng_state_tracker
        ctx.tp_group = tp_group
        ctx.recompute_ctx = recompute_ctx
        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
        ctx.fp8 = fp8_enabled
        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Recompute the forward pass under the saved RNG/autocast/FP8 state, then
        backpropagate ``args`` through the recomputed outputs.

        Returns one ``None`` for each non-tensor argument of `forward`, followed by
        the gradients of the (detached) inputs.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        # Merge saved tensors and stashed non-tensor arguments back into one sequence.
        inputs = tuple(
            saved if saved is not None else stored
            for saved, stored in zip(ctx.saved_tensors, ctx.inputs)
        )
        tracker_fn = ctx.get_rng_state_tracker

        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape),
            )

        # Remember the backward-time RNG states so they can be restored afterwards.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        bwd_tracker_states = tracker_fn().get_states() if tracker_fn is not None else None

        # Rewind every RNG to its state at the start of forward.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
        if tracker_fn is not None:
            tracker_fn().set_states(ctx.fwd_cuda_rng_state_tracker)

        detached_inputs = detach_variable(inputs)
        with torch.enable_grad(), ctx.recompute_ctx:
            with ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx:
                with activation_recompute_forward(
                    activation_recompute=True, recompute_phase=True
                ), autocast(enabled=ctx.fp8, recipe=ctx.fp8_recipe):
                    outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

        # Put the backward-time RNG states back.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
        if tracker_fn is not None:
            tracker_fn().set_states(bwd_tracker_states)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Only outputs that require grad take part in the backward call.
        outputs_with_grad = []
        args_with_grad = []
        for position, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(args[position])
        if not outputs_with_grad:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )

        # No autocast context is needed here: backward implementations read the FP8
        # recipe and enablement from their own stored ctx.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
        )
        # One None each for run_function, distribute_saved_activations,
        # get_rng_state_tracker, tp_group, context_fn and kwargs.
        return (None,) * 6 + grads

456

457
458
459
460
class _CheckpointFrame:
    """
    Per-call state for the non-reentrant checkpoint. It holds the recompute callable,
    the RNG states captured in forward and backward, and the detached activations
    produced by the forward recompute.
    """

    def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):
        self.recompute_fn = recompute_fn
        self.get_rng_state_tracker = get_rng_state_tracker
        # Detached recomputed activations, in the order they were saved.
        self.recomputed = []
        # Number of placeholders handed out by the checkpoint pack hook so far.
        self.count = 0
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Snapshot the CPU, CUDA and (if a tracker is set) tracker RNG states.

        The tuple is stored as the forward snapshot when ``forward`` is True, otherwise
        as the backward snapshot.
        """
        cpu_state = torch.get_rng_state()
        cuda_state = _get_cuda_rng_state(graph_safe=False)
        if self.get_rng_state_tracker is None:
            snapshot = (cpu_state, cuda_state)
        else:
            snapshot = (cpu_state, cuda_state, self.get_rng_state_tracker().get_states())

        if forward:
            self.fwd_rng_states = snapshot
        else:
            self.bwd_rng_states = snapshot

    def restore_rng_states(self, forward=True):
        """Re-apply the forward (``forward=True``) or backward RNG snapshot."""
        snapshot = self.fwd_rng_states if forward else self.bwd_rng_states
        torch.set_rng_state(snapshot[0])
        _set_cuda_rng_state(snapshot[1], graph_safe=False)
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(snapshot[2])


497
498
499
class _recomputation_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """Saved-tensor hooks active while the forward pass is recomputed during backward."""

    def __init__(self, frame):

        def _stash(tensor):
            """
            Record a detached copy of each activation that the forward recomputation
            passes to ``ctx.save_for_backward()``, in save order, on ``frame.recomputed``.
            """
            frame.recomputed.append(tensor.detach())
            return tensor.detach()

        def _passthrough(packed):
            """
            Identity unpack hook. It is not expected to run, because no backward pass
            is taken through the recomputed graph itself.
            """
            return packed

        super().__init__(_stash, _passthrough)


522
523
524
class _checkpoint_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """Saved-tensor hooks active during the checkpointed (first) forward pass."""

    def __init__(self, frame, args, kwargs):

        def _pack_placeholder(tensor):
            """
            Discard each tensor passed to ``ctx.save_for_backward()`` in the checkpointed
            forward. A running integer index is stored in its place.
            """
            del tensor
            slot = frame.count
            frame.count = slot + 1
            return slot

        def _unpack_recomputed(slot):
            """
            Resolve a placeholder index to its recomputed activation. The first call
            (when nothing has been recomputed yet) reruns the forward pass under the
            forward-time RNG states, so `_recomputation_hook` fills ``frame.recomputed``.
            Each entry is handed out once and then cleared.
            """
            if not frame.recomputed:
                # Remember backward-time RNG, rewind to forward-time RNG, recompute,
                # then restore backward-time RNG.
                frame.cache_rng_states(forward=False)
                frame.restore_rng_states(forward=True)
                with _recomputation_hook(frame):
                    frame.recompute_fn(*args, **kwargs)
                frame.restore_rng_states(forward=False)

            activation, frame.recomputed[slot] = frame.recomputed[slot], None
            return activation

        super().__init__(_pack_placeholder, _unpack_recomputed)


def use_reentrant_activation_recompute():
    """Report whether activation recompute is configured for the 'reentrant' method.

    Reads the module-level flag that `checkpoint` sets from its ``use_reentrant`` keyword.
    """
    reentrant = _USE_REENTRANT_ACTIVATION_RECOMPUTE
    return reentrant


def get_activation_recompute_contexts():
    """Build the TE contexts for the checkpointed forward pass and its recompute.

    Returns a ``(forward_ctx, recompute_ctx)`` pair of `activation_recompute_forward`
    instances, both with ``activation_recompute=True``. The first has
    ``recompute_phase=False``; the second has ``recompute_phase=True``.
    """
    forward_ctx, recompute_ctx = (
        activation_recompute_forward(activation_recompute=True, recompute_phase=phase)
        for phase in (False, True)
    )
    return forward_ctx, recompute_ctx


589
def has_te_modules(network):
    """
    Check if there are any Transformer Engine modules in the network.

    For a ``torch.nn.Module``, returns True if any module yielded by
    ``network.modules()`` (which includes ``network`` itself) is an instance of one of
    the TE classes imported below. Any other callable cannot be inspected, so the
    function conservatively returns True.
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )

    if not isinstance(network, torch.nn.Module):
        # A custom class/callable may hide TE modules, so assume it does.
        return True
    return any(isinstance(module, te_classes) for module in network.modules())
Przemek Tredak's avatar
Przemek Tredak committed
619

620

621
@torch._disable_dynamo
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    If :attr:`function` is a ``torch.nn.Module`` containing no Transformer Engine modules,
    the call is forwarded to ``torch.utils.checkpoint.checkpoint``. Otherwise TE's own
    reentrant (``use_reentrant=True``) or non-reentrant implementation is used.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::
        `use_reentrant=False` does not support early stopping, and will execute the entire forward
        pass for the checkpointed module when recomputing activations in the backward pass.

    Parameters
    ----------
    function: Callable
            pytorch module used to run the forward and backward passes using
            the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations: bool, default = False
            if set to `True` and `use_reentrant=True`, first tensor argument is distributed
            across the specified tensor parallel group (`tp_group`) before saving it for the
            backward pass. This has no effect when `use_reentrant=False`.
    get_rng_state_tracker: `Callable`, default = None
            python callable which returns an instance of :func:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
            tensor parallel process group. Used only when `distribute_saved_activations=True`
            and `use_reentrant=True`. If `None`, it falls back to the default group.
    use_reentrant : bool, default = True
            perform checkpointing in reentrant mode.
    context_fn : Callable, default = `torch.utils.checkpoint.noop_context_fn`
            callable returning a `(forward_context, recompute_context)` pair of context
            managers, entered around the forward pass and the recomputation respectively.
    determinism_check : str, default = "default"
            forwarded to `torch.utils.checkpoint.checkpoint`; unused by TE's implementation.
    debug : bool, default = False
            forwarded to `torch.utils.checkpoint.checkpoint`; unused by TE's implementation.
    args : tuple
            tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
            dictionary of string keys for keyword arguments to :attr:`function`.

    The deprecated positional form
    ``checkpoint(function, distribute_saved_activations, get_rng_state_tracker, tp_group, *args)``
    is still accepted and emits a ``DeprecationWarning``.
    """
    global _USE_REENTRANT_ACTIVATION_RECOMPUTE

    # Strip TE's own options so only the user's keyword arguments reach `function`.
    _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
    distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
    tp_group = kwargs.pop("tp_group", None)
    get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

    # Backward compatibility with the legacy positional calling convention.
    uses_legacy_positional_args = (
        len(args) > 3
        and isinstance(args[0], bool)
        and callable(args[1])
        and isinstance(args[2], None | dist_group_type)
    )
    if uses_legacy_positional_args:
        warnings.warn(
            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
            DeprecationWarning,
            stacklevel=2,
        )
        distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]
        args = args[3:]

    context_fn = kwargs.pop("context_fn", noop_context_fn)
    determinism_check = kwargs.pop("determinism_check", "default")
    debug = kwargs.pop("debug", False)

    # A module with no Transformer Engine modules inside is handled by native PyTorch.
    if not has_te_modules(function):
        return torch.utils.checkpoint.checkpoint(
            function,
            *args,
            use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
            context_fn=context_fn,
            determinism_check=determinism_check,
            debug=debug,
            **kwargs,
        )

    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # Activations will be recomputed anyway, so drop the module's FSDP scatter/gather.
        function.fsdp_wrapped = False
        function.fsdp_group = None

    # From here on TE's own checkpoint runs. Custom callables may or may not contain TE
    # modules, so this path must also honor a user-supplied `context_fn`. The
    # torch-only options below are not used.
    del determinism_check, debug

    if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
        if distribute_saved_activations:
            # Without an explicit process group, distribute across the world group.
            assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
            if tp_group is None:
                tp_group = torch.distributed.GroupMember.WORLD

        return _CheckpointFunction.apply(
            function,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            context_fn,
            kwargs,
            *args,
        )

    if distribute_saved_activations:
        warnings.warn(
            "`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
            "The non-reentrant checkpoint implementation does not manually store forward "
            "inputs for the activation recompute in the backward pass, and instead leverages "
            "the autograd engine's pack/unpack hooks."
        )

    user_forward_ctx, user_recompute_ctx = context_fn()
    te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()

    # The recompute must run under the same torch autocast and FP8 state as the forward.
    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
    fp8 = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

    def recompute_fn(*fwd_args, **fwd_kwargs):
        with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx:
            with torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
                with autocast(enabled=fp8, recipe=fp8_recipe):
                    function(*fwd_args, **fwd_kwargs)

    # A fresh frame per forward call holds its RNG snapshots and recomputed activations.
    frame = _CheckpointFrame(recompute_fn, get_rng_state_tracker)
    frame.cache_rng_states(forward=True)

    with _checkpoint_hook(frame, args, kwargs), te_forward_ctx, user_forward_ctx:
        return function(*args, **kwargs)
Przemek Tredak's avatar
Przemek Tredak committed
773

774

775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
class CudaRNGStatesTracker:
    """
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the `add` method, a
    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracker).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Get rng states.

        Returns a new dictionary (so later `add`/`set_states` calls do not
        change it), but the state objects themselves are shared, not cloned.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        states: Dict[str, torch.Tensor]
               A mapping from string names to RNG states.
        """
        self.states_ = states

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        name: str
             string identifier for the RNG state.
        seed: int
             PyTorch seed for the RNG state.

        Raises RuntimeError if `seed` or `name` is already registered. A
        rejected call leaves the tracker unchanged.
        """
        # Validate both keys before touching any bookkeeping. Previously the
        # seed was recorded before the name check, so a duplicate-name error
        # still consumed the seed.
        if seed in self.seeds_:
            raise RuntimeError(f"seed {seed} already exists")
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")

        if graph_safe_rng_available():
            new_state = _get_cuda_rng_state(clone=True)
            new_state.manual_seed(seed)
        else:
            # Seed the global CUDA RNG temporarily, snapshot it, then put the
            # caller's RNG state back even if seeding fails.
            orig_rng_state = _get_cuda_rng_state()
            try:
                torch.cuda.manual_seed(seed)
                new_state = _get_cuda_rng_state(clone=True)
            finally:
                _set_cuda_rng_state(orig_rng_state)

        self.states_[name] = new_state
        self.seeds_.add(seed)
        # Update global states.
        set_all_rng_states(self.states_)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        name: str
             string identifier for the RNG state.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")
        # Get the reference to current rng state.
        orig_cuda_rng_state = _get_cuda_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # this is redundant with graph-safe API
            if not graph_safe_rng_available():
                self.states_[name] = _get_cuda_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


Przemek Tredak's avatar
Przemek Tredak committed
880
def reduce_scatter_along_first_dim(
    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """Reduce-scatter the input tensor across model parallel group.

    The first dimension of ``inp`` is split evenly across the ranks of
    ``tp_group``. Each rank receives one shard of the reduced result, using
    ``reduce_scatter_tensor``'s default reduction op.

    Returns ``(output, handle)``. ``handle`` is None unless ``async_op`` is
    True, and is always None for a single-rank group, where ``inp`` is
    returned unchanged.
    """
    group_size = get_distributed_world_size(tp_group)
    if group_size == 1:
        # Single rank: nothing to communicate.
        return inp, None

    shard_shape = list(inp.size())
    assert (
        shard_shape[0] % group_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    shard_shape[0] //= group_size

    shard = torch.empty(shard_shape, dtype=inp.dtype, device=torch.cuda.current_device())
    work = torch.distributed.reduce_scatter_tensor(
        shard, inp.contiguous(), group=tp_group, async_op=async_op
    )
    return shard, work


903
def _all_gather_fp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]:
    """All-gather FP8 tensor along first dimension.

    Args:
        inp: FP8 storage tensor, or a high-precision tensor that is cast with
            ``quantizer`` (required in that case).
        process_group: Group to gather over.
        async_op: Launch the all-gather asynchronously. The handle is still
            waited on here if an FP8 transpose must be built.
        quantizer: Optional ``Float8Quantizer`` / ``Float8CurrentScalingQuantizer``.
            Its usages are restored to their original values before returning.
        out_shape: Output shape. Defaults to ``inp``'s shape with the first
            dimension multiplied by the group size.

    Returns:
        ``(out, handle)``. ``handle`` is None when the call completed synchronously.

    Raises:
        ValueError: if ``quantizer`` is not an FP8 (per-tensor) quantizer.
        RuntimeError: if ``inp`` is FP8 storage without a quantizer and is not
            a ``Float8Tensor``.
    """
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid
    if quantizer is not None and not isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Cast input tensor to FP8 if needed
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorStorage):
        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
        try:
            inp = quantizer(inp)
        finally:
            # Always restore the caller's usages, even if quantization raises.
            # The columnwise usage is also read again below to decide whether
            # to build the transpose.
            quantizer.set_usage(
                rowwise=init_rowwise_usage,
                columnwise=init_columnwise_usage,
            )

    # Construct output tensor
    out: Float8TensorStorage
    if quantizer is not None:
        dtype = torch.float32
        device = "cuda"
        if isinstance(inp, Float8Tensor):
            dtype = inp.dtype
            device = inp.device
        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
    elif isinstance(inp, Float8Tensor):
        out = inp.make_like(inp, shape=out_shape)
        out._data = torch.empty(
            out_shape,
            dtype=torch.uint8,
            device=inp.device,
        )
        out._transpose = None
        out._transpose_invalid = True
    else:
        raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer")

    # Assume scaling factors are identical across ranks
    out._scale_inv = inp._scale_inv

    # Perform communication
    handle = torch.distributed.all_gather_into_tensor(
        out._data,
        inp._data.contiguous(),
        group=process_group,
        async_op=async_op,
    )

    # Make sure FP8 transpose is populated if needed. The transpose needs the
    # fully gathered data, so any pending async gather is completed first.
    needs_transpose = (
        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        if handle is not None:
            handle.wait()
            handle = None
        out._create_transpose()

    return out, handle


987
988
989
990
991
992
993
994
995
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
    """Return the ``all_gather_usage`` flag of an FP8 block-scaling quantizer.

    A ``DebugQuantizer`` is first unwrapped to its ``parent_quantizer``.
    Returns None for any quantizer that is not a ``Float8BlockQuantizer``.
    """
    base = quantizer.parent_quantizer if isinstance(quantizer, DebugQuantizer) else quantizer
    return base.all_gather_usage if isinstance(base, Float8BlockQuantizer) else None


996
997
998
999
1000
1001
1002
1003
1004
1005
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
    """Set ``all_gather_usage`` (compact format) on an FP8 block-scaling quantizer.

    A ``DebugQuantizer`` is first unwrapped to its ``parent_quantizer``. Does
    nothing for any quantizer that is not a ``Float8BlockQuantizer``.
    """
    target = quantizer.parent_quantizer if isinstance(quantizer, DebugQuantizer) else quantizer
    if not isinstance(target, Float8BlockQuantizer):
        return
    target.all_gather_usage = compact


def _post_process_fp8_blockwise_gather(
    out: Float8BlockwiseQTensorStorage,
    quantizer: Float8BlockQuantizer,
    handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorStorage:
    """Convert a gathered FP8 block-scaled tensor to GEMM-ready format, in place.

    Steps:
        1. Wait on ``handle``, if one is given.
        2. If ``out`` already reports GEMM-ready format, return it unchanged.
        3. Otherwise, transpose the column-wise data (quantizer has
           column-wise usage) and the row-wise scale-inverse (quantizer has
           row-wise usage).
        4. Tag ``out`` as GEMM_READY.

    Returns ``out`` itself.
    """
    if handle is not None:
        handle.wait()

    if out._is_gemm_ready_format():
        return out

    # Without a quantizer neither transform applies; only the format tag changes.
    want_columnwise = quantizer is not None and quantizer.columnwise_usage
    want_rowwise = quantizer is not None and quantizer.rowwise_usage

    # Why only the column-wise *data* is transposed, not its scale-inverse.
    # Take a 256x1024 input:
    #   - Compact column-wise format quantizes it with 128x1 blocks, giving
    #     256x1024 data and a 2x1024 scale-inverse.
    #   - The GEMM-ready equivalent is 1x128 quantization of the transposed
    #     1024x256 tensor. Its scale-inverse is 1024x2, which cuBLAS expects
    #     laid out as 2x1024.
    # So the compact column-wise scale-inverse already has the layout cuBLAS
    # needs.
    if want_columnwise:
        out._transpose_columnwise_data()
    if want_rowwise:
        out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
    return out


@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
    """Handle for asynchronous FP8 blockwise all-gather."""

1039
    tensor: Float8BlockwiseQTensorStorage
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
    quantizer: Float8BlockQuantizer
    async_handle: torch.distributed.Work
    _synchronized: bool = False

    def wait(self) -> None:
        """Wait for the async operation to complete and post-process the tensor."""
        if self._synchronized:
            return
        self.async_handle.wait()
        _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
        self._synchronized = True


1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
def _all_gather_fp8_blockwise(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather FP8 tensor along first dimension for blockwise quantization.

    Returns: quantizer(gather(inp))

    NOTE: The implementation is only going to honor async_op=True for FP8 gather case.
    In the case where tensor shape is not divisible by 128, the implementation will fall back
    to synchronous gather and invoke the quantizer.

    Args:
        inp: High-precision tensor or ``Float8BlockwiseQTensorStorage`` to gather.
        process_group: Group to gather over.
        async_op: On the FP8 path, return an ``_FP8BlockwiseAllGatherAsyncHandle``.
            Its ``wait()`` completes the gather and converts the output to
            GEMM-ready format.
        quantizer: Required ``Float8BlockQuantizer`` (1D blocks of length 128).
            Its ``all_gather_usage`` is changed temporarily and always restored.
        out_shape: Output shape. Defaults to ``inp``'s shape with the first
            dimension multiplied by the group size.

    Returns:
        ``(out, handle)``. ``handle`` is None for synchronous calls and for the
        high-precision fallback path.

    Raises:
        ValueError: if the input type is unsupported, the input holds no data,
            or the quantizer is missing or not a ``Float8BlockQuantizer``.
        NotImplementedError: if the quantizer is not 1D blockwise with block length 128.
        RuntimeError: if the quantized input is not in compact data format.
    """

    # Input tensor attributes
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, Float8BlockwiseQTensorStorage):
        if inp._rowwise_data is not None:
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or"
            f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})"
        )
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid. Every path below dereferences it, so a
    # missing quantizer is rejected here with a clear message rather than
    # failing later with an AttributeError.
    if quantizer is None:
        raise ValueError("FP8 blockwise all-gather requires a Float8BlockQuantizer")
    if not isinstance(quantizer, Float8BlockQuantizer):
        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
    if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
        raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Doing BF16 gather for now as baseline because it's simpler
    if not isinstance(inp, Float8BlockwiseQTensorStorage) and not quantizer.is_quantizable(inp):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
        orig_all_gather_usage = quantizer.all_gather_usage
        quantizer.all_gather_usage = False
        try:
            out = quantizer(out)
        finally:
            # Never leave the caller's quantizer in a modified state.
            quantizer.all_gather_usage = orig_all_gather_usage
        return out, None

    # Implementation of fp8 gather needs to account for:
    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
    # * Gathering non GEMM swizzled scales.

    # Cast input tensor to Float8BlockwiseQTensor with required data.
    # Set to compact usage in case the quantizer is not correctly configured.
    # The output buffer is also allocated while compact usage is active.
    orig_all_gather_usage = quantizer.all_gather_usage
    quantizer.all_gather_usage = True
    try:
        if not isinstance(inp, Float8BlockwiseQTensorStorage):
            inp = quantizer(inp)
        elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
            quantizer.columnwise_usage and inp._columnwise_data is None
        ):
            warnings.warn(
                "Input and quantizer do not have matching usages. "
                "Dequantizing and requantizing to Float8BlockwiseQTensor."
            )
            inp = quantizer(inp.dequantize())

        # Construct Float8BlockwiseQTensor output tensor
        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
    finally:
        quantizer.all_gather_usage = orig_all_gather_usage

    # Begin to do network communication, need to make sure compact format
    if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
        raise RuntimeError(
            "All-gather with FP8 block-wise quantized tensor requires compact data format, "
            f"but found data_format={inp._data_format}"
        )

    # Coalesce NCCL collectives
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as coalescing_manager:

        # Gather Float8BlockwiseQTensor data for row-wise usage
        if quantizer.rowwise_usage:
            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out._rowwise_scale_inv,
                inp._rowwise_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )

        # Gather Float8BlockwiseQTensor data for column-wise usage
        if quantizer.columnwise_usage:
            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out._columnwise_scale_inv,
                inp._columnwise_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )

    handle = coalescing_manager if async_op else None

    # Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper.
    # This means that we need to transpose the gathered columnwise data.
    # Example usage is grad_output tensor, ie. dY in linear backward.
    # We gather two FP8 tensors (rowwise and columnwise) along dim0 and then
    # transpose the columnwise data to match the rowwise data.
    if async_op:
        handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
    else:
        # Sync op: do the transpose here as a post-processing step.
        _post_process_fp8_blockwise_gather(out, quantizer, handle)

    return out, handle
1204
1205


1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
def _swap_first_dims(tensor: torch.Tensor, world_size: int):
    """
    Swap first 2 dimensions of a tensor to fix interleaved
    data format after gathering transposed data.

    For more than 2 dimensions, we squash the trailing dimensions,
    instead of the first few dimensions, that's because the shape
    passed in this function is already transposed.
    """

    shape = tensor.shape
    assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
    first_dim = shape[0]
    flattened_trailing = math.prod(shape[1:])
    assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
    tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
    tensor = tex.swap_first_dims(tensor, out=None)
    return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)


def _post_process_nvfp4_gather(
    out: NVFP4TensorStorage,
    columnwise_data_interleaved: torch.Tensor,
    columnwise_scale_inv_interleaved: torch.Tensor,
    world_size: int,
    handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorStorage:
    """Post-process an NVFP4 all-gather of column-wise (transposed) data.

    Steps:
        1. Wait on ``handle``, if one is given.
        2. De-interleave the gathered column-wise data and scale-inverses
           (gathered along dim 0 rank by rank) with ``_swap_first_dims``.
        3. Store them on ``out``.
        4. Re-pad the scale-inverse with ``pad_columnwise_scale_inv``.

    Returns:
        ``out``, modified in place. It is now actually returned, matching the
        annotated return type.
    """
    if handle is not None:
        handle.wait()

    # Fix the interleaved transposed data from gathering along first dim.
    out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
    out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)

    # Optionally pad the scaling inverse if needed.
    out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
    return out


@dataclass
class _NVFP4AllGatherAsyncHandle:
    """Handle for asynchronous NVFP4 all-gather."""

1250
    output: NVFP4TensorStorage
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
    columnwise_data_interleaved: torch.Tensor
    columnwise_scale_inv_interleaved: torch.Tensor
    world_size: int
    async_handle: torch.distributed.Work
    _synchronized: bool = False

    def wait(self) -> None:
        """Wait for the async operation to complete and post-process the tensor."""
        if self._synchronized:
            return
        self.async_handle.wait()
        _post_process_nvfp4_gather(
            self.output,
            self.columnwise_data_interleaved,
            self.columnwise_scale_inv_interleaved,
            self.world_size,
        )
        self._synchronized = True


def _all_gather_nvfp4(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: NVFP4Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
    """All-gather NVFP4 tensor along first dimension.

    Args:
        inp: High-precision tensor or ``NVFP4TensorStorage`` to gather.
        process_group: Group to gather over.
        async_op: Launch the collectives asynchronously. When column-wise usage
            is requested, the returned handle also fixes up the transposed data
            in its ``wait()``.
        quantizer: NVFP4 quantizer selecting row-wise / column-wise usage.
        out_shape: Output shape. If omitted, it is derived from the packed
            row-wise shape. NOTE(review): confirm that a packed shape is what
            ``make_empty`` and the high-precision fallback expect.

    Returns:
        ``(out, handle)``. ``handle`` is None for synchronous calls and for the
        high-precision fallback path.

    Raises:
        ValueError: if ``inp`` has an unsupported type.
    """

    # Input tensor attributes
    in_shape: Optional[Iterable[int]] = None
    in_shape_t: Optional[Iterable[int]] = None
    device: torch.device
    dtype: torch.dtype

    # Construct packed shapes for input and input_t.
    if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage):
        # High-precision tensor.
        in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
        in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
            NVFP4Quantizer.get_columnwise_shape(inp.size())
        )
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, NVFP4TensorStorage):
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        if inp._columnwise_data is not None:
            in_shape_t = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        dtype = torch.bfloat16
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
            f"found {inp.__class__.__name__})"
        )

    assert in_shape is not None or in_shape_t is not None, "No data found."

    world_size = get_distributed_world_size(process_group)

    if out_shape is None:
        assert in_shape is not None, "Cannot infer out_shape without row-wise data."
        # list(): in_shape may be a torch.Size/tuple, which cannot be added to a list.
        out_shape = [in_shape[0] * world_size] + list(in_shape[1:])

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to NVFP4.
    if (
        not isinstance(inp, NVFP4TensorStorage)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to NVFP4 with required data
    if not isinstance(inp, NVFP4TensorStorage):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to NVFP4."
        )
        inp = quantizer(inp.dequantize())

    # Requantizing may have produced data the original input lacked. Fill in
    # any missing packed shapes from the new buffers, so the gathers below do
    # not index into None.
    if in_shape is None and inp._rowwise_data is not None:
        in_shape = inp._rowwise_data.size()
    if in_shape_t is None and inp._columnwise_data is not None:
        in_shape_t = inp._columnwise_data.size()

    # Construct NVFP4 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Interleaved column-wise gather buffers, fixed up after the gather.
    out_columnwise_data = None
    out_columnwise_scale_inv = None

    # Coalesce NCCL collectives for gathering data and scale inverses.
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as gather_coalescing_manager:

        # Gather NVFP4 data for row-wise usage
        if quantizer.rowwise_usage:

            # Remove padding from NVFP4 scale-inverses
            assert in_shape is not None, "Shape not found."
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )

            # Transfer amax to output.
            out._amax_rowwise = inp._amax_rowwise

        # Gather the transposed NVFP4 data along first dimension. Fix format later.
        if quantizer.columnwise_usage:

            # Remove padding from NVFP4 scale-inverses.
            # For doing an all-gather on transposed scale inverses,
            # we need to remove padding from both dimensions.
            assert in_shape_t is not None, "Transposed shape not found."
            in_scale_inv = inp._columnwise_scale_inv
            # take caution that for in_shape_t, flatten in the trailing dimensions!
            flattened_in_shape0 = in_shape_t[0]
            flattened_in_shape1 = math.prod(in_shape_t[1:])

            # Remove dim0 padding
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]

            # Remove dim1 padding (pack first).
            unpadded_dim1 = flattened_in_shape1 * 2 // 16
            if in_scale_inv.size(1) != unpadded_dim1:
                in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous()

            # Construct tensor to gather transposed scale_inv (interleaved) and launch AG.
            out_columnwise_scale_inv = torch.empty(
                [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]],
                dtype=in_scale_inv.dtype,
                layout=in_scale_inv.layout,
                device=in_scale_inv.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_columnwise_scale_inv,
                in_scale_inv,
                group=process_group,
            )

            # Construct tensor to gather transposed data (interleaved) and launch AG.
            out_columnwise_data = torch.empty(
                [inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]),
                dtype=inp._columnwise_data.dtype,
                layout=inp._columnwise_data.layout,
                device=inp._columnwise_data.device,
            )
            torch.distributed.all_gather_into_tensor(
                out_columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )

            # Transfer amax to output.
            out._amax_columnwise = inp._amax_columnwise

    handle = gather_coalescing_manager if async_op else None

    # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed.
    if async_op and quantizer.columnwise_usage:
        handle = _NVFP4AllGatherAsyncHandle(
            out, out_columnwise_data, out_columnwise_scale_inv, world_size, handle
        )
    elif quantizer.columnwise_usage:
        _post_process_nvfp4_gather(
            out, out_columnwise_data, out_columnwise_scale_inv, world_size, handle
        )

    return out, handle


1445
def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: MXFP8Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]:
    """All-gather MXFP8 tensor along first dimension.

    Args:
        inp: High-precision tensor or ``MXFP8TensorStorage`` to gather.
        process_group: Group to gather over.
        async_op: Launch the coalesced collectives asynchronously.
        quantizer: MXFP8 quantizer selecting row-wise / column-wise usage.
        out_shape: Output shape. Defaults to ``inp``'s shape with the first
            dimension multiplied by the group size.

    Returns:
        ``(out, handle)``. ``handle`` is the coalescing manager when
        ``async_op`` is True on the MXFP8 path, otherwise None.

    Raises:
        ValueError: if ``inp`` has an unsupported type or holds no data.
    """

    # Input tensor attributes
    in_shape: Iterable[int]
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        in_shape = inp.size()
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, MXFP8TensorStorage):
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            in_shape = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got MXFP8 input tensor without any data")
        dtype = torch.bfloat16  # Guess high-precision dtype.
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
            f"found {inp.__class__.__name__})"
        )

    # Output tensor shape
    world_size = get_distributed_world_size(process_group)
    if out_shape is None:
        # in_shape comes from .size() (a torch.Size, i.e. a tuple subclass),
        # which cannot be concatenated to a list directly.
        out_shape = [in_shape[0] * world_size] + list(in_shape[1:])

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to FP8.
    if (
        not isinstance(inp, MXFP8TensorStorage)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to MXFP8 with required data
    if not isinstance(inp, MXFP8TensorStorage):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to MXFP8."
        )
        inp = quantizer(inp.dequantize())

    # Construct MXFP8 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Coalesce NCCL collectives
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as coalescing_manager:

        # Gather MXFP8 data for row-wise usage
        if quantizer.rowwise_usage:

            # Remove padding from MXFP8 scale-inverses
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )

        # Gather MXFP8 data for column-wise usage
        if quantizer.columnwise_usage:

            # Remove padding from MXFP8 scale-inverses.
            # One column-wise scale row covers 32 data rows.
            in_scale_inv = inp._columnwise_scale_inv
            out_scale_inv = out._columnwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )

    handle = coalescing_manager if async_op else None
    return out, handle
1571
1572


Przemek Tredak's avatar
Przemek Tredak committed
1573
def gather_along_first_dim(
    inp: torch.Tensor,
    process_group: dist_group_type,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather tensors and concatenate along first dimension.

    Dispatches on the type of ``inp`` / ``quantizer``: debug tensors, FP8
    (delayed or current scaling), FP8 block scaling, MXFP8 and NVFP4 each use
    a dedicated gather routine. Other quantized inputs fall back to a
    high-precision all-gather.

    Parameters
    ----------
    inp : torch.Tensor
        Local tensor to gather. May be a plain ``torch.Tensor``, a
        ``QuantizedTensorStorage`` or a ``DebugQuantizedTensor``.
    process_group : dist_group_type
        Process group to gather over.
    async_op : bool, default = False
        Forwarded to the plain-tensor collective and to the format-specific
        gather helpers. The debug path and the high-precision quantizer
        fallback always run synchronously and return ``None`` as the handle.
    quantizer : Optional[Quantizer], default = None
        If provided, the gathered output is produced in (or converted to)
        this quantizer's format.

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.distributed.Work]]
        The gathered tensor, whose first dimension is ``world_size`` times
        that of ``inp``, and the communication handle (``None`` when there is
        no pending work to wait on).
    """

    # Return immediately if no communication is required
    world_size = get_distributed_world_size(process_group)
    if world_size == 1:
        if quantizer is not None and not isinstance(inp, QuantizedTensorStorage):
            inp = quantizer(inp)
        return inp, None

    # Debug case - call gather_along_first_dim on each tensor
    if isinstance(inp, DebugQuantizedTensor):
        out_obj = DebugQuantizedTensor(
            rowwise_gemm_tensor=inp.rowwise_gemm_tensor,
            columnwise_gemm_tensor=inp.columnwise_gemm_tensor,
            quantizer=inp.quantizer,
            layer_name=inp._layer_name,
            tensor_name=inp._tensor_name,
        )
        rowwise = inp.get_tensor(False)
        columnwise = inp.get_tensor(True)
        # Only re-quantize with the parent quantizer when the GEMM consumes
        # quantized data for this usage.
        final_quantizer = (
            None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
        )
        rowwise_total = None
        if rowwise is not None:
            rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[
                0
            ]
        out_obj.rowwise_gemm_tensor = rowwise_total
        if rowwise is not columnwise:
            final_quantizer_columnwise = (
                None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
            )
            columnwise_total = None
            if columnwise is not None:
                columnwise_total, _ = gather_along_first_dim(
                    columnwise, process_group, False, final_quantizer_columnwise
                )
            out_obj.columnwise_gemm_tensor = columnwise_total
        else:
            # Sometimes the same object is used both for rowwise and columnwise gemms,
            # and we want to avoid double all-gathers.
            out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor

        return out_obj, None

    # Output tensor dims
    out_shape = list(inp.size())
    out_shape[0] *= world_size

    # FP8 case: delayed scaling or current scaling
    if isinstance(inp, Float8TensorStorage) or isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        return _all_gather_fp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # FP8 block scaling case, block length = 128
    if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
        quantizer, Float8BlockQuantizer
    ):
        return _all_gather_fp8_blockwise(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # MXFP8 case
    if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer):
        assert isinstance(quantizer, MXFP8Quantizer)
        return _all_gather_mxfp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # NVFP4 case
    if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer):
        assert isinstance(quantizer, NVFP4Quantizer)
        return _all_gather_nvfp4(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # High-precision communication for quantized tensors
    if quantizer is not None:
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        if isinstance(inp, QuantizedTensorStorage):
            inp = inp.dequantize()

        # Falling back to high-precision all-gather for Float8BlockQuantizer
        # means that it should directly output GEMM_READY format
        compact = _get_quantizer_format(quantizer)
        _set_quantizer_format(quantizer, compact=False)
        try:
            out = torch.empty(
                out_shape,
                dtype=inp.dtype,
                device=inp.device,
                memory_format=torch.contiguous_format,
            )
            # Same as the plain-tensor path below: the collective needs a
            # contiguous input buffer (inp may be an arbitrary strided view).
            torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
            out = quantizer(out)
        finally:
            # Restore the caller's quantizer configuration even if the
            # gather or quantization raised.
            _set_quantizer_format(quantizer, compact=compact)
        return out, None

    # Dequantize quantized tensor if not supported
    if isinstance(inp, QuantizedTensorStorage):
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        inp = inp.dequantize()

    # Communication for plain PyTorch tensors
    out = torch.empty(
        out_shape,
        dtype=inp.dtype,
        device=inp.device,
        memory_format=torch.contiguous_format,
    )
    handle = torch.distributed.all_gather_into_tensor(
        out,
        inp.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return out, handle
Przemek Tredak's avatar
Przemek Tredak committed
1723
1724


1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
# Global cache of symmetric memory buffers created by get_symmetric_memory_tensor,
# keyed by (numel, dtype, device, process-group name, tag).
symmetric_mem_cache = {}


def get_symmetric_memory_tensor(tensor_numel, tensor_dtype, tensor_device, tp_group, tag=None):
    """
    Gets or creates a symmetric memory tensor with specified properties.

    Reuses cached tensors when available to avoid redundant creation and rendezvous operations.

    Note: This function always returns a 1D tensor.

    Parameters
    ----------
    tensor_numel : int
        Number of elements in the tensor.
    tensor_dtype : torch.dtype
        Data type of the tensor.
    tensor_device : torch.device
        Device on which to allocate the tensor.
    tp_group : Optional[dist_group_type]
        Process group for the rendezvous operation. ``None`` selects the
        default (WORLD) process group, following the torch.distributed
        convention.
    tag : Any, optional
        Optional identifier to further distinguish tensors.

    Returns
    -------
    torch.Tensor
        A symmetric memory tensor with the specified properties.
    """
    if tp_group is None:
        # Previously this crashed on ``None.group_name``.
        tp_group = torch.distributed.group.WORLD

    # Cache key: tensor properties plus the group, so different groups never
    # share a rendezvoused buffer.
    cache_key = (tensor_numel, tensor_dtype, tensor_device, tp_group.group_name, tag)

    buf = symmetric_mem_cache.get(cache_key)
    if buf is None:
        buf = symm_mem.empty(
            tensor_numel,
            dtype=tensor_dtype,
            device=tensor_device,
        )
        # Rendezvous once per buffer; the cache lets later calls skip it.
        symm_mem.rendezvous(buf, group=tp_group)
        symmetric_mem_cache[cache_key] = buf
    return buf


def symmetric_all_reduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
    all_reduce_type: str = "multimem_all_reduce",
):
    """
    Performs an all-reduce operation across multiple processes using symmetric memory.
    If the input tensor is already in the symmetric memory cache we can avoid copy
    overheads by just directly using the input tensor for all reduce.  Externally
    created symmetric memory tensors not in the cache currently will not be able to
    avoid the extra copies.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor to be reduced. The result is written back into it.

    tp_group : Optional[dist_group_type], default=None
        The process group over which to perform the all-reduce operation.
        If None, the default process group is used.

    async_op : bool, default=False
        Whether to perform the operation asynchronously.
        Note: Currently only synchronous operations are supported for symmetric memory variants.

    all_reduce_type : str, default="multimem_all_reduce"
        The type of all-reduce implementation to use. Options include:
        - "nccl": Standard PyTorch distributed all-reduce
        - "multimem_all_reduce": multimem symmetric all-reduce (in place)
        - "two_shot": Two-shot symmetric all-reduce (in place)
        - "one_shot": One-shot symmetric all-reduce (out of place; the
          result is copied back into ``inp``)

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.distributed.Work]]
        - The first element is the input tensor with the all-reduce result.
        - The second element is the async work handle if async_op=True,
          otherwise None.
    """
    assert async_op is False, "Async symmetric ops no supported yet"
    assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch"

    if get_distributed_world_size(tp_group) == 1:
        return inp, None

    if all_reduce_type == "nccl":
        # Standard all-reduce implementation
        handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
        return inp, handle

    if all_reduce_type == "multimem_all_reduce":
        all_reduce_impl = torch.ops.symm_mem.multimem_all_reduce_
    elif all_reduce_type == "two_shot":
        all_reduce_impl = torch.ops.symm_mem.two_shot_all_reduce_
    elif all_reduce_type == "one_shot":
        # Note: no trailing underscore -- this op returns a new tensor.
        all_reduce_impl = torch.ops.symm_mem.one_shot_all_reduce
    else:
        raise TypeError(f"All reduce type {all_reduce_type} is not supported.")

    if tp_group is None:
        # Honor the documented "None means default group" contract.
        tp_group = torch.distributed.group.WORLD
    group_name = tp_group.group_name

    # If the input *is* one of our cached symmetric buffers, reduce it directly
    # and skip the staging copies.
    if any(cached is inp for cached in symmetric_mem_cache.values()):
        buf = inp
    else:
        # Stage through a (cached) symmetric buffer of matching size.
        buf = get_symmetric_memory_tensor(inp.numel(), inp.dtype, inp.device, tp_group)
        buf.copy_(inp.reshape(-1))

    result = all_reduce_impl(buf, "sum", group_name)
    if result is None:
        result = buf

    # In-place ops on a cached input leave the result already in ``inp``.
    # Otherwise (staging buffer, or out-of-place one-shot output) copy it back.
    if result.data_ptr() != inp.data_ptr():
        inp.copy_(result.reshape(inp.shape))

    return inp, None


Przemek Tredak's avatar
Przemek Tredak committed
1871
def allreduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """All-reduce the input tensor across model parallel group.

    The reduction is performed in place on ``inp``. With a single-rank group
    no collective is issued and the handle is ``None``; otherwise the handle
    is whatever ``torch.distributed.all_reduce`` returns for ``async_op``.
    """
    work = None
    # Only talk to the process group when there is more than one rank.
    if get_distributed_world_size(tp_group) != 1:
        work = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
    return inp, work
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895


def _fsdp_data_targets(t: torch.Tensor) -> List[torch.Tensor]:
    """Return the tensors whose data is scattered/gathered on behalf of ``t``.

    Quantized tensors contribute each non-``None`` entry of ``get_data_tensors()``;
    plain tensors contribute themselves.
    """
    if isinstance(t, QuantizedTensor):
        # NOTE(review): assumes get_data_tensors() may hold None for buffers that
        # are not allocated (e.g. an unused transpose); those have nothing to move.
        return [target for target in t.get_data_tensors() if target is not None]
    return [t]


def _fsdp_scatter_tensors(
    fsdp_group: dist_group_type,
    *tensors: torch.Tensor,
):
    """Shard the data of ``tensors`` across ``fsdp_group`` in place.

    Each data tensor has its ``.data`` replaced with the result of
    ``split_tensor_into_1d_equal_chunks`` (presumably this rank's 1D shard).

    Returns a flat list with one entry per scattered data tensor (its original
    shape) and one ``None`` per non-tensor argument, in argument order. Pass it
    unchanged to ``_fsdp_gather_tensors`` together with the same ``tensors``.
    When ``fsdp_group`` is None nothing is touched and ``[]`` is returned.
    """
    shapes = []
    if fsdp_group is None:
        return shapes
    for t in tensors:
        if not isinstance(t, torch.Tensor):
            shapes.append(None)
            continue
        for target in _fsdp_data_targets(t):
            shapes.append(target.data.shape)
            safely_set_viewless_tensor_data(
                target,
                split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True),
            )
    return shapes


def _fsdp_gather_tensors(
    fsdp_group: dist_group_type,
    shapes: List[Optional[Tuple[int, ...]]],
    *tensors: torch.Tensor,
):
    """Undo ``_fsdp_scatter_tensors``: gather each data tensor back to its shape.

    ``shapes`` must be the list returned by ``_fsdp_scatter_tensors`` for the
    same ``tensors``. It holds one entry per data tensor, so a quantized tensor
    with several data tensors consumes several entries. Does nothing when
    ``fsdp_group`` is None.
    """
    if fsdp_group is None:
        return
    # Rebuild the (data tensor | None) sequence in exactly the order the scatter
    # recorded shapes. Previously one shape was consumed per *argument*, which
    # misaligned (or failed the length check) for multi-buffer quantized tensors.
    entries = []
    for t in tensors:
        if isinstance(t, torch.Tensor):
            entries.extend(_fsdp_data_targets(t))
        else:
            entries.append(None)
    assert len(shapes) == len(entries), "Number of tensors and tensor shapes must be equal."
    for s, target in zip(shapes, entries):
        if target is None:
            continue
        assert s is not None, "Internal TE error."
        safely_set_viewless_tensor_data(
            target, gather_split_1d_tensor(target.data, fsdp_group).view(s)
        )
1923
1924
1925
1926
1927
1928
1929
1930
1931


def _is_te_module(module):
    """
    Return ``True`` if ``module`` is an instance of one of the Transformer Engine
    classes that need TE's own checkpoint implementation for activation recompute,
    ``False`` otherwise.
    """
    # Imported inside the function rather than at module level -- presumably to
    # avoid circular imports between these packages and this file.
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_module_types = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    return isinstance(module, te_module_types)


def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    """
    Inject FSDP process group references into FSDP-wrapped TE modules in an FSDP-wrapped root
    module in order to scatter/gather the FP8 weight copies at the same time FSDP scatters/gathers
    its `FlatParameters`.

    Each TE module found directly under an FSDP wrapper gets a ``fsdp_group``
    attribute set to that wrapper's process group.

    Parameters
    ----------
    fsdp_root: torch.nn.Module
               FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
    """
    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."

    def _reject_fp8_primary_weights(owner):
        # Modules created with FP8 primary weights cannot be FSDP-wrapped.
        if hasattr(owner, "primary_weights_in_fp8"):
            assert not owner.primary_weights_in_fp8, (
                "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                "Please initialize your model without the te.quantized_model_init(...) context."
            )

    # Root wrapper: its inner module may itself be a TE module.
    if _is_te_module(fsdp_root.module):
        _reject_fp8_primary_weights(fsdp_root)
        root_state = _get_module_fsdp_state(fsdp_root)
        assert root_state is not None, "Root module does not have a valid _FSDPState."
        fsdp_root.module.fsdp_group = root_state.process_group

    # Every FSDP wrapper in the tree: hand its process group to the TE module it wraps.
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
    for fsdp_state, wrapper in zip(fsdp_states, fsdp_modules):
        inner = wrapper.module
        if not _is_te_module(inner):
            continue
        _reject_fp8_primary_weights(inner)
        inner.fsdp_group = fsdp_state.process_group


class FullyShardedDataParallel(FSDP):
    """
    Drop-in replacement for `torch.distributed.fsdp.FullyShardedDataParallel`.

    After the regular FSDP wrapping it calls ``prepare_te_modules_for_fsdp`` so
    every FSDP-wrapped Transformer Engine module learns its FSDP process group,
    which TE uses to scatter activation tensors after the forward pass and gather
    them again before the backward pass.
    """

    def __init__(self, module, *fsdp_args, **fsdp_kwargs):
        # FSDP must finish building its state first; the injection step reads it.
        super().__init__(module, *fsdp_args, **fsdp_kwargs)
        prepare_te_modules_for_fsdp(self)