distributed.py 51.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""
6
7
from __future__ import annotations

8
from collections.abc import Iterable
9
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
10
from functools import lru_cache
11
import math
12
13
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
14

Przemek Tredak's avatar
Przemek Tredak committed
15
import torch
16
from torch.cuda import _lazy_call, _lazy_init
17
from torch.utils.checkpoint import detach_variable, noop_context_fn
18
19
20
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
Przemek Tredak's avatar
Przemek Tredak committed
21

22
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
Przemek Tredak's avatar
Przemek Tredak committed
23
from .constants import dist_group_type
24
from .fp8 import FP8GlobalStateManager, fp8_autocast
25
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
26
from .tensor.mxfp8_tensor import MXFP8Quantizer
27
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
28
29
30
from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
31
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
32
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
33
34
35
36

__all__ = ["checkpoint", "CudaRNGStatesTracker"]


Przemek Tredak's avatar
Przemek Tredak committed
37
38
39
40
41
42
# Names of the tensor-parallel attributes that `set_tensor_model_parallel_attributes`
# attaches to tensors, with their default values. That function only iterates the keys
# (to refuse re-setting an attribute that already exists).
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

# Whether `checkpoint` uses the reentrant (`_CheckpointFunction`) implementation.
# Overwritten on every `checkpoint` call from its `use_reentrant` kwarg and reported by
# `use_reentrant_activation_recompute()`.
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True

# Flags set on entry to `activation_recompute_forward` and reset to False on exit.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False

# Generator states used by `CudaRNGStatesTracker`; accessed through
# `get_all_rng_states` / `set_all_rng_states`.
_ALL_ACTIVE_RNG_STATES = {}


def get_all_rng_states() -> Dict:
    """Return the generator states used by `CudaRNGStatesTracker`.

    The returned object is the module-level mapping itself, not a copy.
    """
    return _ALL_ACTIVE_RNG_STATES


def set_all_rng_states(states: Dict) -> None:
    """Replace the generator states used by `CudaRNGStatesTracker`.

    Parameters
    ----------
    states : dict
        New mapping of generator states. It is stored by reference (not copied).
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states


def graph_safe_rng_available() -> bool:
    """Return whether CUDA-graph-safe RNG state manipulation is supported.

    The check is a feature detection on the installed PyTorch. It requires
    `torch.cuda.CUDAGraph.register_generator_state` plus the `graphsafe_set_state`,
    `graphsafe_get_state` and `clone_state` methods on `torch.Generator`.
    """
    return (
        hasattr(torch.cuda.CUDAGraph, "register_generator_state")
        and hasattr(torch.Generator, "graphsafe_set_state")
        and hasattr(torch.Generator, "graphsafe_get_state")
        and hasattr(torch.Generator, "clone_state")
    )
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100


def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> torch.Tensor:
    """Fetch the state of a CUDA device's default random number generator.

    Args:
        device: CUDA device index, device string, or `torch.device`. When the resolved
            device carries no index, the current CUDA device is used.
        clone: only relevant on the graph-safe path; return a cloned state
            (`clone_state()`) instead of a reference (`graphsafe_get_state()`).
        graph_safe: use the CUDA-graph-safe accessors when PyTorch provides them;
            otherwise the plain `get_state()` is used.
    """
    _lazy_init()

    if isinstance(device, int):
        device = torch.device("cuda", device)
    elif isinstance(device, str):
        device = torch.device(device)

    index = torch.cuda.current_device() if device.index is None else device.index
    generator = torch.cuda.default_generators[index]

    use_graph_safe = graph_safe_rng_available() and graph_safe
    if not use_graph_safe:
        return generator.get_state()
    # Graph-safe path: either a cloned state or a reference to the live one.
    return generator.clone_state() if clone else generator.graphsafe_get_state()


def _set_cuda_rng_state(
    new_state: torch.Tensor,
    device: Union[int, str, torch.device] = -1,
    graph_safe: bool = True,
) -> None:
    """Set the state of a CUDA device's default random number generator.

    The update is wrapped in a callback handed to `torch.cuda._lazy_call`. That runs
    it immediately when CUDA is already initialized, and otherwise defers it until CUDA
    initialization.

    Args:
        new_state: generator state to install.
        device: `-1` (default) selects the current CUDA device. An int selects that CUDA
            device index. A string or `torch.device` is used as given. A device without
            an index resolves to the current device when the callback runs.
        graph_safe: use `graphsafe_set_state` when graph-safe RNG is available;
            otherwise use `set_state`.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        # The device index is resolved lazily, inside the callback.
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        if graph_safe_rng_available() and graph_safe:
            default_generator.graphsafe_set_state(new_state)
            return
        default_generator.set_state(new_state)

    _lazy_call(cb)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Attach tensor-parallel metadata to `tensor`.

    Sets `tensor_model_parallel`, `partition_dim` and `partition_stride` on the tensor.
    The attributes may only be set once. An AssertionError is raised if any attribute
    named in `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` is already present.
    """
    assert not any(hasattr(tensor, name) for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)

    tp_attributes = (
        ("tensor_model_parallel", is_parallel),
        ("partition_dim", dim),
        ("partition_stride", stride),
    )
    for name, value in tp_attributes:
        setattr(tensor, name, value)


137
@lru_cache
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return the world size of `group` (the default group when `None`).

    Returns 1 when `torch.distributed` has not been initialized.

    NOTE(review): `lru_cache` memoizes results per `group`, including the uninitialized
    fallback of 1. A call made before process-group initialization therefore keeps
    returning 1 afterwards unless `get_distributed_world_size.cache_clear()` is called.
    Confirm this is intended.
    """
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size(group=group)


145
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return this process's rank within `group` (the default group when `None`).

    Raises AssertionError if `torch.distributed` is not initialized. Successful results
    are memoized per `group` by `lru_cache`.
    """
    assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
    return torch.distributed.get_rank(group=group)


def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Optional[Callable],
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize an affine weight for model parallelism.

    Parameters
    ----------
    weight : torch.Tensor
        Parameter passed to `init_method`, which is expected to fill it in place.
    init_method : Callable
        Initializer, called as `init_method(weight)`.
    get_rng_state_tracker : Callable or None
        Returns an object whose `fork()` context manager wraps the initialization
        (e.g. a `CudaRNGStatesTracker`). When `None`, `init_method` runs under the
        current RNG state.
    partition_dim : int, default = 0
        Recorded as the `partition_dim` tensor-parallel attribute.
    stride : int, default = 1
        Recorded as the `partition_stride` tensor-parallel attribute.
    set_tp_attributes : bool, default = True
        Whether to first mark `weight` as tensor-parallel via
        `set_tensor_model_parallel_attributes`.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    if get_rng_state_tracker is None:
        init_method(weight)
        return

    # Run the initializer under the tracker's forked RNG state.
    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's slice of the flattened `tensor`.

    The flattened tensor is cut into chunks of `numel // world_size` elements, and the
    chunk at this rank's position in `tp_group` is selected. With `new_buffer=True` the
    chunk is copied into a freshly allocated buffer on the current CUDA device.
    Otherwise a view into `tensor` is returned.
    """
    chunk_numel = torch.numel(tensor) // get_distributed_world_size(tp_group)
    first = chunk_numel * get_distributed_rank(tp_group)
    stop = first + chunk_numel

    if not new_buffer:
        return tensor.view(-1)[first:stop]

    buffer = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    buffer.copy_(tensor.view(-1)[first:stop])
    return buffer


195
def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor:
    """Opposite of `split_tensor_into_1d_equal_chunks`: gather values from model parallel ranks.

    Allocates a 1D buffer of `numel(tensor) * world_size` elements on the current CUDA
    device and fills it with `torch.distributed.all_gather_into_tensor` over `tp_group`.
    """
    numel_gathered = torch.numel(tensor) * get_distributed_world_size(tp_group)
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(gathered, tensor, group=tp_group)
    return gathered


208
class activation_recompute_forward(AbstractContextManager, ContextDecorator):
    """Context manager used to control the forward runtime behavior when executed
    under the `CheckpointFunction` function. For running FP8, the forward pass will
    run without storing intermediate activations. Instead, the forward pass saves
    the inputs tuple and the calling function. In the backwards pass, these are
    retrieved, and the forward pass is computed again while tracking the intermediate
    activations, followed by calculation of gradients using these values.

    On entry it sets the module-level flags read by `is_fp8_activation_recompute_enabled`
    and `in_fp8_activation_recompute_phase`. On exit both flags are reset to False;
    previous values are not restored.

    It also carries `FP8GlobalStateManager.IS_FIRST_FP8_MODULE` from the checkpointed
    forward pass to the matching recompute. The value is queued when entering the forward
    phase and dequeued (first in, first out) when entering the recompute phase.
    """

    # FIFO shared by all instances: appended in the forward phase, popped from the front
    # in the recompute phase.
    _is_first_fp8_module: List = []

    def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
        super().__init__()
        self.activation_recompute = activation_recompute
        self.recompute_phase = recompute_phase

    def __enter__(self):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase

        if self.activation_recompute and not self.recompute_phase:
            # Checkpointed forward: remember the flag for the later recompute.
            activation_recompute_forward._is_first_fp8_module.append(
                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
            )
        if self.activation_recompute and self.recompute_phase:
            # Recompute: restore the flag recorded by the matching forward pass.
            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
                activation_recompute_forward._is_first_fp8_module.pop(0)
            )

    def __exit__(self, *exc_details):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
        _FP8_ACTIVATION_RECOMPUTE_PHASE = False


def is_fp8_activation_recompute_enabled() -> bool:
    """Report the module-level FP8 activation-recompute flag.

    `activation_recompute_forward.__enter__` sets it to
    `activation_recompute and FP8GlobalStateManager.is_fp8_enabled()`, and `__exit__`
    resets it to False.
    """
    return _FP8_ACTIVATION_RECOMPUTE_ENABLED


def in_fp8_activation_recompute_phase() -> bool:
    """Report whether execution is inside the recompute phase of a checkpointed region.

    `activation_recompute_forward.__enter__` sets it from its `recompute_phase` argument,
    and `__exit__` resets it to False.
    """
    return _FP8_ACTIVATION_RECOMPUTE_PHASE


256
257
258
259
260
261
262
def _get_active_autocast_contexts():
    """
    Return new GPU and CPU autocast context managers that reproduce the autocast state
    (enabled flag, dtype and cache setting) active when this function is called.

    Returns
    -------
    tuple
        `(gpu_autocast_ctx, cpu_autocast_ctx)`.
    """
    autocast_cached = torch.is_autocast_cache_enabled()

    if hasattr(torch, "get_autocast_dtype"):
        # Newer PyTorch exposes device-generic autocast queries. The CUDA/CPU-specific
        # queries and `torch.{cuda,cpu}.amp.autocast` are deprecated there, so prefer
        # the generic API to avoid deprecation warnings on every call.
        gpu_autocast_ctx = torch.amp.autocast(
            "cuda",
            enabled=torch.is_autocast_enabled("cuda"),
            dtype=torch.get_autocast_dtype("cuda"),
            cache_enabled=autocast_cached,
        )
        cpu_autocast_ctx = torch.amp.autocast(
            "cpu",
            enabled=torch.is_autocast_enabled("cpu"),
            dtype=torch.get_autocast_dtype("cpu"),
            cache_enabled=autocast_cached,
        )
    else:
        gpu_autocast_ctx = torch.cuda.amp.autocast(
            torch.is_autocast_enabled(), torch.get_autocast_gpu_dtype(), autocast_cached
        )
        cpu_autocast_ctx = torch.cpu.amp.autocast(
            torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype(), autocast_cached
        )

    return gpu_autocast_ctx, cpu_autocast_ctx


278
class _CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset.

    The forward pass runs `run_function` under `torch.no_grad()`, so intermediate
    activations are not kept. The backward pass restores the forward RNG, autocast and
    FP8 settings, recomputes the forward with gradients enabled, and backpropagates
    through the recomputed graph.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_rng_state_tracker: Union[Callable, None],
        tp_group: Union[dist_group_type, None],
        context_fn: Union[Callable, None],
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Call forward function while saving state to be able to
        redo the computation later."""
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        if context_fn is not None:
            forward_ctx, recompute_ctx = context_fn()
        else:
            forward_ctx, recompute_ctx = noop_context_fn()

        # Preserve torch autocast context for the backward pass
        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()

        with torch.no_grad(), forward_ctx:
            with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
                outputs = run_function(*args, **kwargs)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True),
            )

        # Store everything. Tensor args go through save_for_backward and non-tensor args
        # stay on ctx. Each list holds None where the other holds the value, so the two
        # can be re-merged positionally in backward.
        ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
        tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
        ctx.save_for_backward(*tensor_inputs)

        fp8 = FP8GlobalStateManager.is_fp8_enabled()
        ctx.get_rng_state_tracker = get_rng_state_tracker
        ctx.tp_group = tp_group
        ctx.recompute_ctx = recompute_ctx
        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
        ctx.fp8 = fp8
        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call backward function with activation recomputation.

        `args` are the incoming gradients for the forward outputs.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        # Re-merge saved tensors with the non-tensor arguments kept on ctx.
        inputs = tuple(
            t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
        )

        get_rng_state_tracker = ctx.get_rng_state_tracker

        # Undo the forward-pass split of the first input.
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape),
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
        if get_rng_state_tracker is not None:
            bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
        if get_rng_state_tracker is not None:
            get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
            activation_recompute=True, recompute_phase=True
        ), fp8_autocast(
            enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe
        ):
            outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
        if get_rng_state_tracker is not None:
            get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Only backpropagate through outputs that actually require grad, pairing each
        # with its incoming gradient.
        outputs_with_grad = []
        args_with_grad = []
        for i, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )

        # backward does not require entering autocast context because
        # backward implementations already retrieve fp8 recipe and
        # enablement from stored ctx.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
        )
        # One None per non-tensor forward argument: run_function,
        # distribute_saved_activations, get_rng_state_tracker, tp_group, context_fn, kwargs.
        return (None, None, None, None, None, None) + grads

418

419
420
421
422
class _CheckpointFrame:
    """
    Storage frame for forward RNG states and detached activations from the forward recompute.

    Used by the non-reentrant checkpoint. `_checkpoint_hook` assigns each saved tensor
    an index via `count`. `_recomputation_hook` fills `recomputed` during the
    recompute, in save order.
    """

    def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Optional[Callable]):
        self.recompute_fn = recompute_fn
        # Detached activations captured during the forward recompute.
        self.recomputed = []
        # Number of tensors packed so far during the checkpointed forward pass.
        self.count = 0
        self.get_rng_state_tracker = get_rng_state_tracker
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Cache fwd/bwd RNG states in the frame to restore later.

        Stores `(cpu_state, cuda_state)`, plus the tracker's states when a
        `get_rng_state_tracker` was given.
        """
        rng_states = (
            torch.get_rng_state(),
            _get_cuda_rng_state(graph_safe=False),
        )
        if self.get_rng_state_tracker is not None:
            rng_states += (self.get_rng_state_tracker().get_states(),)

        if forward:
            self.fwd_rng_states = rng_states
        else:
            self.bwd_rng_states = rng_states

    def restore_rng_states(self, forward=True):
        """Restore fwd/bwd RNG states that were previously cached into the frame.

        Raises RuntimeError if the requested states were never cached.
        """
        rng_states = self.fwd_rng_states if forward else self.bwd_rng_states
        if rng_states is None:
            phase = "forward" if forward else "backward"
            raise RuntimeError(
                f"No {phase} RNG states cached in this checkpoint frame; "
                "call cache_rng_states() first."
            )

        torch.set_rng_state(rng_states[0])
        _set_cuda_rng_state(rng_states[1], graph_safe=False)
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(rng_states[2])


459
460
461
class _recomputation_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the activation recompute phase."""

    def __init__(self, frame):

        def pack_hook(x):
            """
            Packing hook for each recomputed activation passed into the `ctx.save_for_backward()`
            call in the forward recomputation. A detached copy is recorded in
            `frame.recomputed` (in save order) and also returned as the packed value.
            """
            detached = x.detach()
            frame.recomputed.append(detached)
            return detached

        def unpack_hook(x):
            """
            No-op unpack hook that will never be called because the backward pass for the
            forward recomputation is never triggered.
            """
            return x

        super().__init__(pack_hook, unpack_hook)


484
485
486
class _checkpoint_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass.

    Saved tensors are replaced by integer indices. On the first unpack, the forward is
    recomputed with `frame.recompute_fn(*args, **kwargs)` under the cached forward RNG
    states, and the resulting activations are served from `frame.recomputed`.
    """

    def __init__(self, frame, args, kwargs):

        def pack_hook(x):
            """
            Packing hook for each tensor passed into `ctx.save_for_backward()` call in the
            forward pass. Since this is the first forward pass, we discard the tensor and instead
            pack a placeholder tensor index into the autograd engine context.
            """
            del x
            idx = frame.count
            frame.count += 1
            return idx

        def unpack_hook(idx):
            """
            Unpacking hook for each tensor that comes out of the `ctx.saved_tensors` call in the
            backward pass. The first time this is called, the _recomputation_hook will save all the
            activation tensors from `ctx.save_for_backward()` in the forward recomputation into the
            _CheckpointFrame. Subsequent calls will simply return the already recomputed activation
            tensor at the given index of the _CheckpointFrame storage.

            Each slot is cleared after it is handed out, so unpacking the same index a second
            time returns None.
            """
            if not frame.recomputed:
                # Store current RNG states in the backward pass
                frame.cache_rng_states(forward=False)

                # Set RNG states to what we saved before the forward pass
                frame.restore_rng_states(forward=True)

                # Recompute the forward pass
                with _recomputation_hook(frame):
                    frame.recompute_fn(*args, **kwargs)

                # Restore RNG states back to the backward pass
                frame.restore_rng_states(forward=False)

            # Return the already recomputed activation tensor at the given index, dropping
            # the frame's reference so it can be freed once autograd is done with it.
            activation = frame.recomputed[idx]
            frame.recomputed[idx] = None
            return activation

        super().__init__(pack_hook, unpack_hook)


def use_reentrant_activation_recompute():
    """Report whether activation recompute runs in 'reentrant' mode.

    Reflects the `use_reentrant` value of the most recent `checkpoint` call. This is
    `True` before any call.
    """
    return _USE_REENTRANT_ACTIVATION_RECOMPUTE


def get_activation_recompute_contexts():
    """Build the TE context managers for a checkpointed region.

    Returns
    -------
    tuple
        `(forward_ctx, recompute_ctx)`. Both are `activation_recompute_forward` instances
        with `activation_recompute=True`; `recompute_phase` is False for the first and
        True for the second.
    """
    forward_ctx, recompute_ctx = (
        activation_recompute_forward(activation_recompute=True, recompute_phase=phase)
        for phase in (False, True)
    )
    return forward_ctx, recompute_ctx


551
def has_te_modules(network):
    """
    Check if there are any Transformer Engine modules in the network.

    For a `torch.nn.Module`, returns True if it (or any submodule) is an instance of one
    of the TE module classes listed below. Any other callable cannot be inspected, so
    True is returned for it conservatively.
    """
    # Imported inside the function (presumably to avoid import cycles at module load).
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )

    if isinstance(network, torch.nn.Module):
        return any(isinstance(module, te_classes) for module in network.modules())

    # Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module,
    # so just assume that it has TE modules just to be safe.
    return True
Przemek Tredak's avatar
Przemek Tredak committed
579

580

581
@torch._disable_dynamo
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::
        `use_reentrant=False` does not support early stopping, and will execute the entire forward
        pass for the checkpointed module when recomputing activations in the backward pass.

    If :attr:`function` is a `torch.nn.Module` containing no Transformer Engine modules, the
    call is delegated to `torch.utils.checkpoint.checkpoint`. Passing
    `distribute_saved_activations`, `get_rng_state_tracker` and `tp_group` positionally
    (before the tensor arguments) is deprecated but still accepted.

    Parameters
    ----------
    function: Callable
            pytorch module used to run the forward and backward passes using
            the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations: bool, default = False
            if set to `True` and `use_reentrant=True`, first tensor argument is distributed
            across the specified tensor parallel group (`tp_group`) before saving it for the
            backward pass. This has no effect when `use_reentrant=False`.
    get_rng_state_tracker: `Callable`, default = None
            python callable which returns an instance of :func:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
            tensor parallel process group. Used only when `distribute_saved_activations=True`
            and `use_reentrant=True`. If `None`, it falls back to the default group.
    use_reentrant : bool, default = True
            perform checkpointing in reentrant mode. The value is also recorded globally
            and reported by `use_reentrant_activation_recompute()`.
    context_fn : Callable, default = `torch.utils.checkpoint.noop_context_fn`
            callable returning a `(forward_context, recompute_context)` pair of context
            managers, entered around the checkpointed forward pass and the recompute,
            respectively.
    determinism_check : str, default = "default"
            forwarded to `torch.utils.checkpoint.checkpoint`; only used when the call is
            delegated there.
    debug : bool, default = False
            forwarded to `torch.utils.checkpoint.checkpoint`; only used when the call is
            delegated there.
    args : tuple
            tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
            dictionary of string keys for keyword arguments to :attr:`function`.
    """
    # Pop out te.distributed.checkpoint() arguments
    global _USE_REENTRANT_ACTIVATION_RECOMPUTE
    _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
    distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
    tp_group = kwargs.pop("tp_group", None)
    get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

    # Ensure backward compatibility with the legacy positional form
    # `checkpoint(function, distribute_saved_activations, get_rng_state_tracker, tp_group, *args)`.
    # A tuple is used for the isinstance check (instead of `None | dist_group_type`) so it
    # does not raise TypeError on Python versions without runtime `X | Y` unions.
    if (
        len(args) > 3
        and isinstance(args[0], bool)
        and callable(args[1])
        and isinstance(args[2], (type(None), dist_group_type))
    ):
        warnings.warn(
            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
            DeprecationWarning,
            stacklevel=2,
        )
        distribute_saved_activations = args[0]
        get_rng_state_tracker = args[1]
        tp_group = args[2]
        args = args[3:]

    # Trigger the native PyTorch checkpoint if the function is not or does not contain a
    # Transformer Engine module.
    context_fn = kwargs.pop("context_fn", noop_context_fn)
    determinism_check = kwargs.pop("determinism_check", "default")
    debug = kwargs.pop("debug", False)
    if not has_te_modules(function):
        return torch.utils.checkpoint.checkpoint(
            function,
            *args,
            use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
            context_fn=context_fn,
            determinism_check=determinism_check,
            debug=debug,
            **kwargs,
        )

    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
        setattr(function, "fsdp_wrapped", False)
        setattr(function, "fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
    # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
    #       cannot be sure there are no TE modules inside the function. It also means we might run
    #       the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
    #       user context function.
    del determinism_check, debug
    if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
        # If saved activations need to be distributed but there is no process group,
        # default to the world group.
        if distribute_saved_activations:
            assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
            tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

        return _CheckpointFunction.apply(
            function,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            context_fn,
            kwargs,
            *args,
        )

    if distribute_saved_activations:
        warnings.warn(
            "`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
            "The non-reentrant checkpoint implementation does not manually store forward "
            "inputs for the activation recompute in the backward pass, and instead leverages "
            "the autograd engine's pack/unpack hooks."
        )

    user_forward_ctx, user_recompute_ctx = context_fn()
    te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()

    # Preserve the torch autocast contexts from the forward pass during recompute phase.
    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

    # Likewise capture the FP8 settings so the recompute matches the forward pass.
    fp8 = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

    def recompute_fn(*recompute_args, **recompute_kwargs):
        # Re-run `function` with grad enabled under the forward pass's TE, user,
        # autocast and FP8 contexts; outputs are discarded, since only the activations
        # captured by `_recomputation_hook` are needed.
        with torch.autograd.enable_grad(), te_recompute_ctx, user_recompute_ctx, (
            torch_gpu_amp_forward_ctx
        ), torch_cpu_amp_forward_ctx, fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
            function(*recompute_args, **recompute_kwargs)

    # Initialize a new checkpoint frame for each new forward pass.
    new_frame = _CheckpointFrame(
        recompute_fn,
        get_rng_state_tracker,
    )
    new_frame.cache_rng_states(forward=True)

    with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
        out = function(*args, **kwargs)

    return out
Przemek Tredak's avatar
Przemek Tredak committed
733

734

735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
class CudaRNGStatesTracker:
    """
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the `add` method, a
    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracked RNG states and no registered seeds).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Get rng states. Returns a shallow copy of the dictionary so the caller holds
        direct references to the states, not a reference to the tracker's own dictionary.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        states: Dict[str, torch.Tensor]
               A mapping from string names to RNG states.
        """
        self.states_ = states

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        name: str
             string identifier for the RNG state.
        seed: int
             PyTorch seed for the RNG state.

        Raises RuntimeError if `seed` or `name` is already registered. In that case
        the tracker's seeds and states are left unchanged.
        """
        # Validate both keys before mutating any bookkeeping, so a rejected call does
        # not leave `seed` marked as used.
        if seed in self.seeds_:
            raise RuntimeError(f"seed {seed} already exists")
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")
        self.seeds_.add(seed)

        if graph_safe_rng_available():
            new_state = _get_cuda_rng_state(clone=True)
            new_state.manual_seed(seed)
            self.states_[name] = new_state
            # Update global states.
            set_all_rng_states(self.states_)
        else:
            # Get the current rng state.
            orig_rng_state = _get_cuda_rng_state()
            # Set the new state and store it.
            torch.cuda.manual_seed(seed)
            self.states_[name] = _get_cuda_rng_state(clone=True)
            # Reset rng state to what it was.
            _set_cuda_rng_state(orig_rng_state)
            # Update global states.
            set_all_rng_states(self.states_)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        name: str
             string identifier for the RNG state.

        Raises KeyError if `name` has not been registered with `add`.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")
        # Get the reference to current rng state.
        orig_cuda_rng_state = _get_cuda_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Store the current state back under `name`; this is redundant with the
            # graph-safe API.
            if not graph_safe_rng_available():
                self.states_[name] = _get_cuda_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


Przemek Tredak's avatar
Przemek Tredak committed
840
def reduce_scatter_along_first_dim(
    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """Reduce-scatter the input tensor across model parallel group.

    The first dimension of `inp` must be divisible by the group size. Each rank receives
    a slice whose first dimension is `inp.size(0) // world_size`. Returns `(output, handle)`.
    With a single rank, `inp` itself is returned unchanged with a `None` handle.
    """
    world_size = get_distributed_world_size(tp_group)
    # Single-rank group: nothing to communicate.
    if world_size == 1:
        return inp, None

    out_dims = list(inp.shape)
    assert (
        out_dims[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    out_dims[0] //= world_size

    output = torch.empty(out_dims, dtype=inp.dtype, device=torch.cuda.current_device())
    work = torch.distributed.reduce_scatter_tensor(
        output,
        inp.contiguous(),
        group=tp_group,
        async_op=async_op,
    )
    return output, work


863
def _all_gather_fp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
    """All-gather FP8 tensor along first dimension.

    Parameters
    ----------
    inp: torch.Tensor
         Local shard. If it is not already a `Float8TensorBase`, it is cast with
         `quantizer` (row-wise data only) before communication.
    process_group: dist_group_type
         Group to gather across.
    async_op: bool, default = False
         Launch the all-gather asynchronously and return its handle.
    quantizer: Quantizer, optional
         `Float8Quantizer` or `Float8CurrentScalingQuantizer`, used to cast the input
         and to allocate the output.
    out_shape: list of int, optional
         Output shape. Defaults to the input shape with the first dimension multiplied
         by the group size.

    Returns
    -------
    (out, handle): the gathered FP8 tensor and the communication handle. The handle is
    `None` when the communication is synchronous, or when it had to be waited on in
    order to build the FP8 transpose.

    Raises
    ------
    ValueError
         If `quantizer` is not an FP8 quantizer, or the input is not FP8 and no
         quantizer was provided.
    RuntimeError
         If the input is an FP8 tensor base that is not a `Float8Tensor` and no
         quantizer was provided.
    """
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid
    if quantizer is not None and not isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Cast input tensor to FP8 if needed
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorBase):
        if quantizer is None:
            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
        try:
            inp = quantizer(inp)
        finally:
            # Restore the caller's usage flags even if quantization raises.
            quantizer.set_usage(
                rowwise=init_rowwise_usage,
                columnwise=init_columnwise_usage,
            )

    # Construct output tensor
    out: Float8TensorBase
    if quantizer is not None:
        dtype = torch.float32
        device = "cuda"
        if isinstance(inp, Float8Tensor):
            dtype = inp.dtype
            device = inp.device
        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
    elif isinstance(inp, Float8Tensor):
        out = inp.make_like(inp, shape=out_shape)
        # Allocate the raw FP8 buffer from the target shape. `torch.empty_like` expects a
        # tensor, so passing the shape list to it would raise.
        out._data = torch.empty(
            out_shape,
            dtype=torch.uint8,
            device=inp.device,
        )
        out._transpose = None
        out._transpose_invalid = True
    else:
        raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")

    # Assume scaling factors are identical across ranks
    out._scale_inv = inp._scale_inv

    # Perform communication
    handle = torch.distributed.all_gather_into_tensor(
        out._data,
        inp._data.contiguous(),
        group=process_group,
        async_op=async_op,
    )

    # Make sure FP8 transpose is populated if needed
    needs_transpose = (
        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        # The transpose is computed from the gathered data, so communication must finish first.
        if handle is not None:
            handle.wait()
            handle = None
        out._create_transpose()

    return out, handle


945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
def _all_gather_fp8_blockwise(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,  # pylint: disable=unused-argument
    quantizer: Optional[Quantizer] = None,
    out_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather FP8 tensor along first dimension for blockwise quantization.

    Returns: quantizer(gather(inp))

    NOTE: The implementation is not sophisticated enough to honor async_op=True.
    In some cases it falls back to synchronous gather and invokes the quantizer.

    Only a high-precision input with a 1D, block-length-128 `Float8BlockQuantizer` is
    currently supported. Gathering already-quantized blockwise data raises
    NotImplementedError.

    Raises
    ------
    ValueError
        If the input type is unsupported, a blockwise input has no data, or the
        quantizer is not a `Float8BlockQuantizer`.
    NotImplementedError
        If the quantizer is not 1D with block length 128, or the input is already
        quantized (or no quantizer is given).
    """

    # Input tensor attributes
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, Float8BlockwiseQTensorBase):
        if inp._rowwise_data is not None:
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
            f"found {inp.__class__.__name__})"
        )
    world_size = get_distributed_world_size(process_group)

    # Check that quantizer is valid. The block-shape check is only meaningful when a
    # quantizer was given; previously a `None` quantizer crashed here with AttributeError.
    if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
    if quantizer is not None and not (
        quantizer.block_scaling_dim == 1 and quantizer.block_len == 128
    ):
        raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size

    # Doing BF16 gather for now as baseline because it's simpler
    if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        # all_gather_into_tensor requires a contiguous input.
        torch.distributed.all_gather_into_tensor(
            out, inp.contiguous(), group=process_group, async_op=False
        )
        out = quantizer(out)
        return out, None
    # Implementation of fp8 gather needs to account for:
    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
    # * Gathering non GEMM swizzled scales.
    # * Refer to scaffold code when implementing at:
    # https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
    raise NotImplementedError("fp8 blockwise allgather not yet implemented")


1013
def _all_gather_mxfp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
    async_op: bool = False,
    quantizer: MXFP8Quantizer,
    out_shape: Optional[list[int]] = None,
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
    """All-gather MXFP8 tensor along first dimension.

    Parameters
    ----------
    inp: torch.Tensor or MXFP8TensorBase
         Local shard, either high precision or already MXFP8-quantized.
    process_group: dist_group_type
         Group to gather across.
    async_op: bool, default = False
         When True, the NCCL collectives are coalesced asynchronously and the coalescing
         manager is returned as the handle.
    quantizer: MXFP8Quantizer
         Determines which usages (row-wise and/or column-wise) are gathered, and is used
         to quantize high-precision inputs.
    out_shape: list of int, optional
         Output shape. Defaults to the input shape with the first dimension multiplied
         by the group size.

    Returns
    -------
    (out, handle): gathered tensor and handle (`None` for synchronous communication or
    for the high-precision fallback path).
    """

    # Input tensor attributes
    in_shape: Iterable[int]
    device: torch.device
    dtype: torch.dtype
    if isinstance(inp, torch.Tensor):
        in_shape = inp.size()
        device = inp.device
        dtype = inp.dtype
    elif isinstance(inp, MXFP8TensorBase):
        # Take the shape from the data tensor itself. The previous code called
        # `.size()` on a torch.device, which has no such method.
        if inp._rowwise_data is not None:
            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
        elif inp._columnwise_data is not None:
            in_shape = inp._columnwise_data.size()
            device = inp._columnwise_data.device
        else:
            raise ValueError("Got MXFP8 input tensor without any data")
        # The quantized data only carries an FP8 dtype; use BF16 as the nominal dtype.
        dtype = torch.bfloat16
    else:
        raise ValueError(
            "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
            f"found {inp.__class__.__name__})"
        )

    # Output tensor shape. `in_shape[1:]` is a torch.Size, so convert it before
    # concatenating with a list.
    world_size = get_distributed_world_size(process_group)
    if out_shape is None:
        out_shape = [in_shape[0] * world_size] + list(in_shape[1:])

    # For cases where inp has dimensions that cannot be quantized,
    # we gather in high precision followed by a cast to FP8.
    if (
        not isinstance(inp, MXFP8TensorBase)
        and quantizer is not None
        and not quantizer.is_quantizable(inp)
    ):
        out = torch.empty(
            out_shape,
            dtype=dtype,
            device=device,
            memory_format=torch.contiguous_format,
        )
        # all_gather_into_tensor requires a contiguous input.
        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
        out = quantizer(out)
        return out, None

    # Cast input tensor to MXFP8 with required data
    if not isinstance(inp, MXFP8TensorBase):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
        quantizer.columnwise_usage and inp._columnwise_data is None
    ):
        warnings.warn(
            "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to MXFP8."
        )
        inp = quantizer(inp.dequantize())

    # Construct MXFP8 output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Coalesce NCCL collectives
    with torch.distributed._coalescing_manager(
        group=process_group,
        device=device,
        async_ops=async_op,
    ) as coalescing_manager:

        # Gather MXFP8 data for row-wise usage
        if quantizer.rowwise_usage:

            # Remove padding from MXFP8 scale-inverses
            in_scale_inv = inp._rowwise_scale_inv
            out_scale_inv = out._rowwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1])
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv[flattened_in_shape0 * world_size :].zero_()
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._rowwise_data,
                inp._rowwise_data,
                group=process_group,
            )

        # Gather MXFP8 data for column-wise usage
        if quantizer.columnwise_usage:

            # Remove padding from MXFP8 scale-inverses. Column-wise scales cover 32 rows each.
            in_scale_inv = inp._columnwise_scale_inv
            out_scale_inv = out._columnwise_scale_inv
            flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
            if in_scale_inv.size(0) != flattened_in_shape0:
                in_scale_inv = in_scale_inv[:flattened_in_shape0]
                out_scale_inv[flattened_in_shape0 * world_size :].zero_()
                out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

            # Launch all-gathers
            torch.distributed.all_gather_into_tensor(
                out_scale_inv,
                in_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                out._columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )

    handle = coalescing_manager if async_op else None
    return out, handle
1143
1144


Przemek Tredak's avatar
Przemek Tredak committed
1145
def gather_along_first_dim(
    inp: torch.Tensor,
    process_group: dist_group_type,
    async_op: bool = False,
    quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """
    All-gather tensors and concatenate along first dimension.

    Dispatch order:
    - FP8 delayed/current scaling input or quantizer -> `_all_gather_fp8`
    - FP8 blockwise input or quantizer -> `_all_gather_fp8_blockwise`
    - MXFP8 input or quantizer -> `_all_gather_mxfp8`
    - debug tensors -> gather each underlying tensor recursively
    - any other quantizer -> high-precision gather, then quantize
    - otherwise -> plain high-precision all-gather (quantized inputs are dequantized first)

    Returns `(output, handle)`. Some paths communicate synchronously and return a
    `None` handle even when `async_op=True`.
    """

    # Return immediately if no communication is required
    world_size = get_distributed_world_size(process_group)
    if world_size == 1:
        if quantizer is not None and not isinstance(inp, QuantizedTensor):
            inp = quantizer(inp)
        return inp, None

    # Output tensor dims
    out_shape = list(inp.size())
    out_shape[0] *= world_size

    # FP8 case: delayed scaling or current scaling
    if isinstance(inp, Float8TensorBase) or isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        return _all_gather_fp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # FP8 block scaling case, block length = 128
    if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
        return _all_gather_fp8_blockwise(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # MXFP8 case
    if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
        assert isinstance(quantizer, MXFP8Quantizer)
        return _all_gather_mxfp8(
            inp,
            process_group,
            async_op=async_op,
            quantizer=quantizer,
            out_shape=out_shape,
        )

    # Debug case - call gather_along_first_dim on each tensor.
    # The input object is updated in place and returned.
    if isinstance(inp, DebugQuantizedTensor):
        out_obj = inp
        rowwise = inp.get_tensor(False)
        columnwise = inp.get_tensor(True)
        final_quantizer = (
            None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
        )
        rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
        out_obj.rowwise_gemm_tensor = rowwise_total
        if rowwise is not columnwise:
            final_quantizer_columnwise = (
                None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
            )
            columnwise_total, _ = gather_along_first_dim(
                columnwise, process_group, False, final_quantizer_columnwise
            )
            out_obj.columnwise_gemm_tensor = columnwise_total
        else:
            # Both usages share one tensor, so the column-wise GEMM tensor must also be the
            # gathered result. The previous code self-assigned rowwise_gemm_tensor (a no-op)
            # and left the column-wise tensor un-gathered.
            out_obj.columnwise_gemm_tensor = rowwise_total
        return out_obj, None

    # High-precision communication for quantized tensors
    if quantizer is not None:
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        if isinstance(inp, QuantizedTensor):
            inp = inp.dequantize()
        out = torch.empty(
            out_shape,
            dtype=inp.dtype,
            device=inp.device,
            memory_format=torch.contiguous_format,
        )
        # all_gather_into_tensor requires a contiguous input (same as the plain path below).
        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
        out = quantizer(out)
        return out, None

    # Dequantize quantized tensor if not supported
    if isinstance(inp, QuantizedTensor):
        warnings.warn(
            "Attempting to all-gather an unsupported quantized tensor. "
            "Falling back to high-precision all-gather."
        )
        inp = inp.dequantize()

    # Communication for plain PyTorch tensors
    out = torch.empty(
        out_shape,
        dtype=inp.dtype,
        device=inp.device,
        memory_format=torch.contiguous_format,
    )
    handle = torch.distributed.all_gather_into_tensor(
        out,
        inp.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return out, handle
Przemek Tredak's avatar
Przemek Tredak committed
1261
1262
1263


def allreduce(
    inp: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """All-reduce the input tensor across model parallel group.

    The reduction is done in place on `inp`, which is returned together with the
    communication handle. The handle is `None` for a single-rank group or a
    synchronous call.
    """
    work = None
    # A single-rank group needs no communication.
    if get_distributed_world_size(tp_group) != 1:
        work = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op)
    return inp, work
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287


def _fsdp_scatter_tensors(
    fsdp_group: dist_group_type,
    *tensors: torch.Tensor,
):
    shapes = []
    if fsdp_group is not None:
        for t in tensors:
            if isinstance(t, torch.Tensor):
1288
1289
1290
1291
1292
1293
1294
                targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
                for target in targets:
                    shapes.append(target.data.shape)
                    safely_set_viewless_tensor_data(
                        target,
                        split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True),
                    )
1295
1296
1297
1298
1299
1300
1301
            else:
                shapes.append(None)
    return shapes


def _fsdp_gather_tensors(
    fsdp_group: dist_group_type,
1302
    shapes: List[Tuple[int, ...]],
1303
1304
1305
1306
1307
1308
1309
    *tensors: torch.Tensor,
):
    if fsdp_group is not None:
        assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal."
        for s, t in zip(shapes, tensors):
            if isinstance(t, torch.Tensor):
                assert s is not None, "Internal TE error."
1310
1311
1312
1313
1314
                targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
                for target in targets:
                    safely_set_viewless_tensor_data(
                        target, gather_split_1d_tensor(target.data, fsdp_group).view(s)
                    )
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325


def _is_te_module(module):
    """
    Return True when `module` is an instance of one of the Transformer Engine classes
    that require the TE checkpoint implementation for activation recompute.
    """
    # Imported inside the function, presumably to avoid circular imports at load time.
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )
    return isinstance(module, te_classes)


def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    """
    Inject FSDP process group references into FSDP-wrapped TE modules in an FSDP-wrapped root
    module in order to scatter/gather the FP8 weight copies at the same time FSDP scatters/gathers
    its `FlatParameters`.

    Parameters
    ----------
    fsdp_root: torch.nn.Module
               FSDP-wrapped root module that may contain FSDP-wrapped TE modules.

    Raises
    ------
    AssertionError
        If `fsdp_root` is not FSDP-wrapped, a TE module has primary weights in FP8,
        or the root module has no FSDP state.
    """
    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."

    def _assert_no_fp8_primary_weights(te_module: torch.nn.Module) -> None:
        # Inspect the wrapped TE module itself (not the FSDP wrapper), for both the root
        # and the submodules.
        if hasattr(te_module, "primary_weights_in_fp8"):
            assert not te_module.primary_weights_in_fp8, (
                "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                "Please initialize your model without the te.fp8_model_init(...) context."
            )

    # If the root module is a TE module, inject FSDP information into it
    if _is_te_module(fsdp_root.module):
        _assert_no_fp8_primary_weights(fsdp_root.module)
        root_state = _get_module_fsdp_state(fsdp_root)
        assert root_state is not None, "Root module does not have a valid _FSDPState."
        setattr(fsdp_root.module, "fsdp_group", root_state.process_group)

    # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
    fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
    for state, fsdp_module in zip(fsdp_states, fsdp_modules):
        if _is_te_module(fsdp_module.module):
            _assert_no_fp8_primary_weights(fsdp_module.module)
            setattr(fsdp_module.module, "fsdp_group", state.process_group)


class FullyShardedDataParallel(FSDP):
    """
    Subclass of `torch.distributed.fsdp.FullyShardedDataParallel` for models containing
    Transformer Engine modules. Once FSDP has wrapped the model,
    `prepare_te_modules_for_fsdp` hands each wrapped TE module its FSDP process group.
    The modules can then scatter their activation tensors after each forward pass and
    gather them before the backward pass.
    """

    def __init__(self, module, *fsdp_args, **fsdp_kwargs):
        # Build the FSDP wrap first, then pass its process groups to the TE modules.
        super().__init__(module, *fsdp_args, **fsdp_kwargs)
        prepare_te_modules_for_fsdp(self)