linear.py 74.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union, List
7
8
from functools import reduce
from operator import mul as multiply_op
9
import warnings
10
11
12

import torch

13
import transformer_engine_torch as tex
14

15
from transformer_engine.common.recipe import Recipe
16
from transformer_engine.pytorch import torch_version
17

18
from .base import (
19
20
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
21
    get_ub,
22
    get_workspace,
23
24
25
26
27
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
28
from ._common import noop_cat, WeightGradStore
29
from ..quantization import FP8GlobalStateManager
30
31
from ..utils import (
    cast_if_needed,
32
    clear_tensor_data,
33
    divide,
34
    init_method_constant,
35
36
    requires_grad,
    needs_quantized_gemm,
37
    assert_dim_for_fp8_exec,
38
    assert_dim_for_all_gather,
39
40
    nvtx_range_pop,
    nvtx_range_push,
41
42
43
44
45
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
46
    symmetric_all_reduce,
47
48
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
49
    is_fp8_activation_recompute_enabled,
50
    in_fp8_activation_recompute_phase,
51
52
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
53
54
)
from ..cpp_extensions import (
55
    general_gemm,
56
)
57
from ..constants import GemmParallelModes, dist_group_type
58
from ..jit import no_torch_dynamo
59
from ..graph import is_graph_capturing
60
from ..quantized_tensor import (
61
    QuantizedTensor,
62
    QuantizedTensorStorage,
63
64
65
66
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
67
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
68
from ..tensor.mxfp8_tensor import MXFP8Quantizer
69
from ..tensor.utils import is_custom
70
from ..export import is_in_onnx_export_mode, assert_warmed_up
71
72
73
74
75
76
from ..cpu_offload import (
    is_cpu_offload_enabled,
    start_offload,
    mark_not_offload,
    mark_activation_offload,
)
77
from ...debug.pytorch.debug_state import TEDebugState
78

79
80
81
82
83
84
85
86
87
88
89
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "Linear",
]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        grad_weight_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        save_original_input: bool = False,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        """Forward pass of the linear layer: ``y = x w^T (+ bias)``.

        Prepares the input by casting it to ``activation_dtype`` or quantizing
        it. For column-parallel sequence-parallel layers, it also all-gathers
        the input via NCCL or Userbuffers.

        Prepares the weight by casting it, or by fetching a quantized
        workspace through ``module.get_weight_workspace``.

        Runs the GEMM with ``general_gemm``. For row-parallel layers, the
        result is then reduce-scattered or all-reduced across ``tp_group``.

        When ``is_grad_enabled`` is set, the tensors and configuration needed
        by ``backward`` are cached on ``ctx``.

        Args (selected):
            weight: Weight of shape ``(out_features, in_features)``.
            inp: Input whose last dimension must equal ``in_features``.
            bias: Optional bias applied by the GEMM. For quantized GEMMs with
                an FP32 activation dtype it is cast to BF16, because cuBLAS
                does not support FP8 GEMM with an FP32 bias.
            is_first_microbatch:
                - ``None``: passes ``cache_name=None`` to the weight
                  workspace, presumably disabling caching.
                - ``True``: requests a workspace update.
                - ``False``: does not request an update, so the cached
                  ``"weight"`` workspace is presumably reused.
            fp8: Enable quantized execution.
            fp8_calibration: When ``fp8`` is off, only calibrate the
                input/weight quantizers.
            save_original_input: Cache the original (unquantized) input for
                backward instead of a quantized copy. Not supported with
                ``Float8Quantizer`` (DelayedScaling recipe).
            debug: Debug mode. All Userbuffers comm+GEMM overlap (forward and
                backward) is disabled and plain NCCL communication is used.

        Returns:
            The output tensor. With Userbuffers reduce-scatter overlap, or in
            sequence-parallel row-parallel mode, it is scattered along the
            first dimension across ``tp_group``.
        """
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        assert inp.shape[-1] == in_features, "GEMM not possible"

        # Turn off Userbuffers (comm+GEMM overlap) in debug mode.
        # This MUST happen before deciding whether the input is all-gathered
        # with NCCL. Otherwise, disabling ub_overlap_ag_fprop afterwards would
        # skip the input all-gather entirely.
        # ub_overlap_ag_dgrad is disabled as well, since backward would
        # otherwise still request a Userbuffers communicator.
        if debug:
            ub_overlap_rs_fprop = False
            ub_overlap_ag_fprop = False
            ub_overlap_ag_dgrad = False
            ub_overlap_rs_dgrad = False
            ub_bulk_wgrad = False
            ub_bulk_dgrad = False

        # Configure tensor-parallel communication
        tp_world_size = get_distributed_world_size(tp_group)
        backward_needs_input = is_grad_enabled and weight.requires_grad
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )

        # Configure Userbuffers communication (comm+GEMM overlap)
        ub_obj = None
        ub_type = None
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.AG

        # custom recipe check
        custom = is_custom(input_quantizer) or is_custom(weight_quantizer)

        # ------------------------------------------------------
        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp  # Input tensor to save for backward (maybe sharded)
        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
        own_quantized_input = False
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
            if save_original_input:
                assert not isinstance(
                    input_quantizer, Float8Quantizer
                ), "DelayedScaling recipe is not supported with save_original_input"

        if with_input_all_gather_nccl or ub_overlap_ag_fprop:  # All-gather input tensor

            # Cast local input tensor if needed
            if fp8 or debug:
                if input_quantizer is None:
                    raise ValueError("Missing quantizer for input tensor")
                if not isinstance(inputmat, QuantizedTensorStorage) and not custom:
                    own_quantized_input = True
                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                    if isinstance(
                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                    ):
                        # All-gather is not supported with FP8 column-wise data
                        input_quantizer.set_usage(columnwise=False)
                    if save_original_input:
                        # No need for column-wise data since this
                        # tensor will not be cached for backward pass
                        input_quantizer.set_usage(columnwise=False)
                        own_quantized_input = False
                    inputmat = input_quantizer(inputmat)
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP

            # Initialize gathered input tensor; the forward GEMM only
            # consumes row-wise data
            quantizer = None
            if fp8 or debug:
                quantizer = input_quantizer
                quantizer.set_usage(rowwise=True, columnwise=False)
            if with_input_all_gather_nccl:  # Perform NCCL all-gather
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
                    tp_group,
                    quantizer=quantizer,
                )
            elif ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_obj,
                    inputmat,
                    quantizer,
                    tp_group,
                )

        else:  # Do not all-gather input tensor
            if fp8 or debug:
                if isinstance(inputmat, QuantizedTensorStorage):
                    inputmat.update_usage(rowwise_usage=True)
                else:
                    if input_quantizer is None:
                        raise ValueError("Missing quantizer for input tensor")
                    input_quantizer.set_usage(
                        rowwise=True, columnwise=backward_needs_input and not save_original_input
                    )
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
            inputmat_total = inputmat

        if is_cpu_offload_enabled():
            start_offload(inputmat)
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
        # ------------------------------------------------------
        # Input tensor is ready for GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare weight tensor
        # ------------------------------------------------------
        weightmat = weight
        if fp8 or debug:
            # Configure quantizer
            # No need to set the quantizer states if weight is already quantized
            if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
                # Column-wise weight data is used by the dgrad GEMM in backward.
                # It is also kept when activation recompute is enabled and we
                # are not yet in the recompute phase (presumably so the
                # recompute pass can reuse the quantized weight).
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
            elif isinstance(weight, QuantizedTensor):
                # If weight is already quantized, no need to set quantizer states
                weight_quantizer = weight._quantizer
            # Get quantized weight
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)

        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
        # ------------------------------------------------------
        # Weight tensor is ready for GEMM...
        # ------------------------------------------------------

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Choose whether to use GEMM kernel with split accumulator
        use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Output buffer for Userbuffers reduce-scatter
        reduce_scatter_out = None
        if ub_overlap_rs_fprop:
            out_shape = list(inp.shape)
            out_shape[0] //= tp_world_size
            out_shape[-1] = out_features
            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

        # ------------------------------------------------------
        # Forward GEMM
        # Note: y = x * w^T
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm")
        gemm_out, *_, reduce_scatter_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=activation_dtype,
            bias=bias,
            use_split_accumulator=use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=reduce_scatter_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")
        # ------------------------------------------------------
        # Finished forward GEMM...
        # ------------------------------------------------------

        # Deallocate GEMM input tensor if no longer needed
        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
        # deallocated by GC. Manually deallocating is a temporary hack.
        if with_input_all_gather_nccl:
            clear_tensor_data(inputmat_total)
            inputmat_total = None

        # ------------------------------------------------------
        # Prepare output tensor
        # Note: Perform tensor-parallel communication
        # ------------------------------------------------------
        out = None
        if ub_overlap_rs_fprop:
            out = reduce_scatter_out
        elif parallel_mode == "row" and tp_size > 1:
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            out = gemm_out
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
        else:
            out = gemm_out
        # ------------------------------------------------------
        # Output tensor is ready to return...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Cache state for backward pass
        # ------------------------------------------------------

        if is_grad_enabled:
            if save_original_input:
                inputmat = inp

            ctx.weight_quantizer = weight_quantizer

            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            # Discard unneeded data in input tensor.
            # The input was quantized by input_quantizer, so its all-gather
            # capabilities decide which usage to keep. This matches the same
            # decision made in backward. When own_quantized_input is True,
            # input_quantizer is guaranteed to be non-None (checked above).
            if (
                backward_needs_input
                and own_quantized_input
                and isinstance(inputmat, QuantizedTensorStorage)
            ):
                if (
                    ctx.backward_input_needs_gather
                    and input_quantizer.supports_only_rowwise_all_gather()
                ):
                    # All-gather is not supported with FP8 column-wise data
                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
                else:
                    # Discard row-wise data since it is not needed in backward pass
                    inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)

            # Cached input tensor
            saved_inputmat = None
            if backward_needs_input:
                saved_inputmat = inputmat

            if cpu_offloading and saved_inputmat is not None:
                mark_activation_offload(saved_inputmat)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            if cpu_offloading:
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            mark_not_offload(weight, weightmat, bias)
            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.input_quantizer = input_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                # This check is needed to ensure that main_grad is not created
                # during the forward pass when using MCore FSDP as it creates
                # the main_grad buffer lazily before backprop
                if hasattr(weight, "__fsdp_param__"):
                    # MCore FSDP creates main_grad lazily before backward
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            ctx.debug = debug
            ctx.custom = custom
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False

            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # is_first_fp8_module() may update IS_FIRST_FP8_MODULE; during
                # activation recompute the original value is restored.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # ------------------------------------------------------
        # Cached state for backward pass is ready...
        # ------------------------------------------------------

        return out
490
491

    @staticmethod
492
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
493
        # pylint: disable=missing-function-docstring
494

495
496
497
498
499
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

500
        with torch.cuda.nvtx.range("_Linear_backward"):
501
502
503
504
            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
505

506
507
508
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
509
510
511

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
512
                ctx.main_grad_func()
513
514
515
516
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

517
518
519
            if ctx.cpu_offloading:
                if ctx.grad_added_to_main_grad:
                    weight = ctx.weight_object
520
521
            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                weight.main_grad = main_grad
522

523
524
525
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
526
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
527
528
529
530
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
531
                weight_fp8,
532
            )
533
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
534

535
            # Configure Userbuffers communication (comm+GEMM overlap)
536
            ctx.ub_obj_gradout = None
537
            ub_obj_dgrad = None
538
            ub_obj_wgrad = None
539
540
            ub_type_dgrad = None
            ub_type_wgrad = None
541
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
542
            if ctx.ub_overlap_ag:
543
                # Overlap grad_output all-gather with dgrad compute
544
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
545
546
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
547
548
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
549
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
550
551
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
552
553
554
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
555
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
556
557
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
558
559
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
560
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
561
                    ub_type_wgrad = tex.CommOverlapType.RS
562
563
564
565
566
567
568
569

            # --------------------------------------------------
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
            # --------------------------------------------------

            # Unmodified grad output tensor
            grad_output_arg = grad_output
570

571
572
573
574
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
575
576
577
578
579
580
581
                quantizer = ctx.grad_output_quantizer
                quantizer.set_usage(rowwise=True, columnwise=True)
                if ctx.ub_overlap_ag:
                    # Userbuffers only supports communication for one
                    # tensor usage at a time. Configure quantizer with
                    # usage for only dgrad GEMM.
                    quantizer.set_usage(columnwise=False)
582

583
584
585
586
587
588
589
590
591
592
593
594
595
            # Adjust the quantization direction approach depending
            # on whether wgrad calculations will be performed.
            # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization
            #       results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!`
            # NOTE: For `ctx.bias is True`, selected quantize kernel errors with
            #       `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.`
            if (
                not ctx.use_bias
                and not ctx.requires_wgrad
                and ctx.grad_output_quantizer is not None
            ):
                ctx.grad_output_quantizer.set_usage(columnwise=False)

596
            # Prepare grad output tensor
597
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
598
599
600
601
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
602
603
604
605
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
606
            )
607
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
608

609
610
611
612
613
614
615
616
617
618
            # --------------------------------------------------
            # Grad output tensor is ready for computing grad input...
            # --------------------------------------------------

            # --------------------------------------------------
            # Prepare input tensor
            # Note: Input tensor is needed for wgrad GEMM.
            # Tensor-parallel communication is overlapped with dgrad
            # GEMM.
            # --------------------------------------------------
619
            inputmat_total = None
620
            inputmat_total_work = None
621
622
            if ctx.requires_wgrad:
                if ctx.fp8 or ctx.debug:
623
                    if isinstance(inputmat, QuantizedTensorStorage):
624
625
                        # Input tensor is already quantized
                        pass
626
                    elif ctx.debug or ctx.custom:
627
628
629
630
                        # Debug quantizer will be applied immediately before wgrad GEMM
                        pass
                    else:
                        # Quantize input tensor
631
                        quantizer = ctx.input_quantizer
632
                        if quantizer.supports_only_rowwise_all_gather():
633
                            # All-gather is not supported with FP8 column-wise data
634
635
636
637
                            quantizer.set_usage(
                                rowwise=True,
                                columnwise=not ctx.backward_input_needs_gather,
                            )
638
                        else:
639
                            quantizer.set_usage(rowwise=False, columnwise=True)
640
641
                        inputmat = quantizer(inputmat)
                else:
642
                    if isinstance(inputmat, QuantizedTensorStorage):
643
644
645
                        inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                    else:
                        inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
646
            if ctx.backward_input_needs_gather:
647
                quantizer = None
648
                if ctx.fp8 or ctx.debug:
649
                    quantizer = ctx.input_quantizer
650
                    if quantizer.supports_only_rowwise_all_gather():
651
652
653
654
655
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
                if ctx.ub_bulk_dgrad:
                    inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                        ub_obj_dgrad,
                        inputmat,
                        quantizer,
                        ctx.tp_group,
                    )
                else:
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                    inputmat_total, inputmat_total_work = gather_along_first_dim(
                        inputmat,
                        ctx.tp_group,
                        async_op=True,
                        quantizer=quantizer,
                    )
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
672
673
            else:
                inputmat_total = inputmat
674
675
676
            # --------------------------------------------------
            # Input tensor is ready for computing grad weight...
            # --------------------------------------------------
677

678
            # --------------------------------------------------
679
            # Compute grad input tensor
680
681
            # --------------------------------------------------

682
683
            dgrad = None
            dgrad_work = None
684
            if ctx.requires_dgrad:
685

686
                # Make sure required data is available
687
                if isinstance(grad_output, QuantizedTensorStorage):
688
                    grad_output.update_usage(rowwise_usage=True)
689
690
691
                if ctx.weight_quantizer is not None and isinstance(
                    weight_fp8, QuantizedTensorStorage
                ):
692
693
694
695
                    weight_fp8.update_usage(columnwise_usage=True)

                # Choose whether to use GEMM kernel with split accumulator
                use_split_accumulator = _2X_ACC_DGRAD
696
697
698
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
699
                        use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
700

701
702
703
704
705
706
707
708
709
710
                # Update grad input quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # Output buffers for Userbuffers reduce-scatter
                gemm_out = None
                reduce_scatter_out = None
                if ctx.ub_overlap_rs_dgrad:
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
711
                    )
712
713
                elif ctx.ub_bulk_wgrad:
                    gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
714

715
716
                # dgrad GEMM
                # Note: dx = dy * w
717

718
719
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
                gemm_out, *_, reduce_scatter_out = general_gemm(
720
721
722
723
724
725
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
726
                    out=gemm_out,
727
                    out_dtype=ctx.activation_dtype,
728
                    use_split_accumulator=use_split_accumulator,
729
730
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
731
                    extra_output=reduce_scatter_out,
732
733
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
734
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
735

736
737
                # Prepare grad input tensor
                # Note: Perform tensor-parallel communication
738
                if ctx.ub_overlap_rs_dgrad:
739
740
741
742
                    dgrad = reduce_scatter_out
                elif ctx.ub_bulk_wgrad:
                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
                elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
743
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
744
                    dgrad = gemm_out
745
746
747
748
749
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
750
                        )
751
                    else:
752
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
753
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
754
755
756
757
758
759
760
761
762
763
                else:
                    dgrad = gemm_out

            # --------------------------------------------------
            # Grad input tensor has been computed...
            # --------------------------------------------------

            # --------------------------------------------------
            # Compute grad weight
            # --------------------------------------------------
764

765
766
            wgrad = None
            if ctx.requires_wgrad:
767

768
769
770
                # Prepare input tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
771
772
773
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
774
                if ctx.fp8 or ctx.debug:
775
                    if isinstance(inputmat_total, QuantizedTensorStorage):
776
777
778
779
780
781
782
783
784
                        inputmat_total.update_usage(columnwise_usage=True)
                    else:
                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat_total = ctx.input_quantizer(inputmat_total)

                # Prepare grad output tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
785
                    # UB does not support pipelined overlapping grad output
786
787
788
                    # all-gather with wgrad GEMM. Also, we can't
                    # convert row-scaled MXFP8 to column-scaled, so we
                    # can't reuse the grad output that was gathered
789
                    # for the dgrad GEMM. We work around by explicitly
790
791
792
793
794
795
                    # overlapping the AG operation with the dgrad GEMM.

                    # Get the communication stream from the dgrad GEMM to use for the AG
                    dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

                    # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
796
                    ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
797

798
                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
799
800
801
802
803
804
805

                    # We use the send stream to copy into the userbuffers.
                    # This is the same stream that we will use to access the data in the AG,
                    # so we dont need to add any syncs yet.
                    with torch.cuda.stream(dgrad_send_stream):
                        grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                            ub_obj_overlap_wgrad,
806
                            grad_output_arg,
807
                            ctx.grad_output_quantizer,
808
809
                            ctx.tp_group,
                        )
810
811
812
813
814

                    # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
                    tex.bulk_overlap_ag_with_external_gemm(
                        ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
                    )
815

816
                if ctx.fp8 or ctx.debug:
817
                    if isinstance(grad_output, QuantizedTensorStorage):
818
819
820
821
                        grad_output.update_usage(columnwise_usage=True)
                    else:
                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                        grad_output = ctx.grad_output_quantizer(grad_output)
822

823
824
825
826
827
828
829
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

830
831
832
833
834
835
836
837
838
                # Figure out whether to output wgrad GEMM directly into main grad
                if ctx.is_first_microbatch is not None:
                    accumulate_wgrad_into_param_main_grad = (
                        ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                    )
                else:
                    accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

                # Output buffer for overlapping FP8 grad input
839
                # reduce-scatter with wgrad GEMM
840
                reduce_scatter_out = None
841
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
842
843
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
844
845
                    )

846
847
848
849
                # Arguments to include in wgrad GEMM closure
                wgrad_gemm_kwargs = {
                    "workspace": get_workspace(),
                    "out_dtype": (
850
851
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
852
                    "quantization_params": ctx.grad_weight_quantizer,
853
854
855
856
857
                    "accumulate": (
                        accumulate_wgrad_into_param_main_grad
                        if not getattr(weight, "overwrite_main_grad", False)
                        else False
                    ),
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
                    "layout": "NT",
                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
                    "use_split_accumulator": use_split_accumulator,
                    "grad": True,
                    "ub": ub_obj_wgrad,
                    "ub_type": ub_type_wgrad,
                    "extra_output": reduce_scatter_out,
                    "bulk_overlap": ctx.ub_bulk_wgrad,
                }

                def wgrad_gemm(
                    x: torch.Tensor,
                    dy: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                    """Perform wgrad GEMM: dw = dy^T * x

                    May be fused with bgrad computation.

                    May be called outside of this function to enable
                    some advanced communication/compute overlapping.

                    """
                    nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                    nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
                    return dw, db

                # Choose whether to call wgrad GEMM now or delay
887
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
888
889
890
891
892
893
894
895
896
897
898
                    if (
                        wgrad_gemm_kwargs["ub"] is not None
                        or wgrad_gemm_kwargs["ub_type"] is not None
                        or wgrad_gemm_kwargs["extra_output"] is not None
                        or wgrad_gemm_kwargs["bulk_overlap"]
                    ):
                        raise NotImplementedError(
                            "Delayed weight grad computation is not supported "
                            "with Userbuffers (tensor-parallel communication overlapping)"
                        )
                    ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
899
900
                else:

901
902
903
904
                    # Call wgrad GEMM now
                    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                    # Update grad bias if needed
905
906
907
908
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

909
                    # Deallocate tensors if permitted
910
                    if ctx.owns_input:
911
912
913
914
                        # Input tensor is internal
                        clear_tensor_data(inputmat_total)
                    elif ctx.backward_input_needs_gather:
                        # Gathered input tensor is internal
915
                        clear_tensor_data(inputmat_total)
916
917
918
                    if ctx.parallel_mode == "row" and ctx.sequence_parallel:
                        # Gathered grad output tensor is internal
                        clear_tensor_data(grad_output)
919

920
                # Update grad input if overlapping reduce-scatter with wgrad GEMM
921
922
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
923
                        dgrad = reduce_scatter_out
924
                    else:
925
926
927
928
929
                        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

            # --------------------------------------------------
            # Grad weight has been computed...
            # --------------------------------------------------
930

931
            # Don't return grad bias if not needed
932
933
934
            if not ctx.use_bias:
                grad_bias = None

935
            # Make sure all tensor-parallel communication is finished
936
937
938
939
940
941
942
943
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
944
            # Handle custom DDP from mcore.
945
946
947
948
949
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
950
                weight.grad_added_to_main_grad = True
951
                if getattr(weight, "zero_out_wgrad", False):
952
953
954
955
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
956
                    )
957
                else:
958
959
960
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
961
                    )
962
963
964
965
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
966

967
        # Update FP8 scaling factors if needed
968
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
969
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
970
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
971
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
972

973
        # Scatter fp8 weight buffers
974
        if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
975
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
976
        return (
977
            wgrad,
978
979
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
980
981
982
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
983
            None,  # wgrad_store
984
985
986
987
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
988
989
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
990
991
992
993
994
995
996
997
998
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
999
1000
1001
1002
1003
1004
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
1005
            None,  # ub_name
1006
            None,  # fp8_output
1007
            None,  # fsdp_group
1008
1009
            None,  # module
            None,  # skip_fp8_weight_update
1010
            None,  # symmetric_ar_type
1011
            None,  # save_original_input
1012
            None,  # debug
1013
1014
1015
1016
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the name passed to `get_rng_state_tracker` to select a specific RNG tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = `None`
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names. The split sizes must sum to
                      `out_features`, and in column-parallel mode each split size must be
                      divisible by the tensor-parallel size.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    name : str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in. If the weight tensor additionally
                             has the attribute `overwrite_main_grad` set to `True`, the weight
                             gradient overwrites `main_grad` instead of being accumulated into it.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused into subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether or not to delay weight gradient computation. If set to `True`,
                         it's the user's responsibility to call `module.backward_dw` to compute
                         weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = `None`
                   Type of symmetric memory all-reduce to use during the forward pass.
                   This can help in latency-bound communication situations.
                   Requires PyTorch version 2.7.0 or higher. When set to `None`, standard
                   all-reduce is used.
    save_original_input : bool, default = `False`
                       If set to `True`, always saves the original input tensor rather than the
                       cast tensor. In some scenarios, the input tensor is used by multiple modules,
                       and saving the original input tensor may reduce the memory usage.
                       Not supported with the FP8 DelayedScaling recipe.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
        """Create the weight (and optional bias) parameters for this layer.

        The weight/bias are allocated as single contiguous buffers, optionally
        split along dim 0 into several named parameters (``parameters_split``)
        and sharded across tensor-parallel ranks according to ``parallel_mode``.
        All arguments are documented in the class docstring.

        Raises
        ------
        ValueError
            If ``parameters_split`` is empty, or its sizes do not sum to
            ``out_features``.
        TypeError
            If ``parameters_split`` is neither ``None``, a dict, nor an iterable
            of strings.
        RuntimeError
            If a split size is not divisible by the TP size in column-parallel
            mode, or if FP8 parameters would have to be split into several params.
        AssertionError
            For an unsupported ``parallel_mode``, splitting on the ``meta`` device,
            requesting Userbuffers overlap without ``ub_name``, or using
            ``symmetric_ar_type`` with PyTorch older than 2.7.
        """
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.symmetric_ar_type = symmetric_ar_type
        self.save_original_input = save_original_input
        self.name = name

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        # Tensor-parallel group. Without an explicit group, `tp_size` is taken as
        # given and the group is only registered immediately when TP is unused.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Shard the dimension that is distributed across TP ranks
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        column_sp = self.parallel_mode == "column" and self.sequence_parallel
        self.ub_overlap_ag_fprop = column_sp and ub_overlap_ag
        self.ub_overlap_rs_dgrad = column_sp and ub_overlap_rs_dgrad
        # Bulk dgrad/wgrad overlap is mutually exclusive with RS-dgrad overlap
        self.ub_bulk_dgrad = column_sp and ub_bulk_dgrad and not self.ub_overlap_rs_dgrad
        self.ub_bulk_wgrad = column_sp and ub_bulk_wgrad and not self.ub_overlap_rs_dgrad

        # Row parallel TP overlap options
        row_sp = self.parallel_mode == "row" and self.sequence_parallel
        self.ub_overlap_rs_fprop = row_sp and ub_overlap_rs
        self.ub_overlap_ag_dgrad = row_sp and ub_overlap_ag

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            # The old message interpolated `ub_name`, which is always None here.
            assert ub_name is not None, (
                "Comm+GEMM overlap is enabled but `ub_name` is None; "
                "pass the name of an initialized Userbuffers communicator."
            )
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params; the registered params are views into them
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits (sizes here are global, before TP sharding)
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for split_name, split_size in parameters_split.items():
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(split_name, str) for split_name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for split_name in parameters_split:
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # No bias: store an empty placeholder tensor under each bias name
            for bias_name in self.bias_names:
                empty_bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, bias_name, empty_bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        self.gemm_bias_unfused_add = bool(self.parallel_mode == "row" and self.apply_bias)

        # With delayed wgrad, gradients are produced later by `backward_dw`,
        # so the per-parameter post-backward hooks are skipped.
        if self.wgrad_store.delay_wgrad_compute():
            for param_name, param in self.named_parameters():
                if param_name in self.weight_names or param_name in self.bias_names:
                    param.skip_backward_post_hook = True

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        After the base-class initialization, the layer's quantizers are
        customized for the active recipe type: Float8 current scaling,
        Float8 block scaling or NVFP4. Other recipe types (e.g. MXFP8) are
        not customized here.

        Parameters
        ----------
        fwd : bool
            ``True`` to customize the forward-pass quantizers, ``False`` for
            the backward-pass quantizers.
        recipe : Recipe
            Recipe forwarded to the base-class ``set_meta_tensor``.
        """
        super().set_meta_tensor(fwd, recipe)

        # Customize quantizers based on each recipe & layer configs.
        # NOTE(review): the recipe used for customization is re-read from
        # FP8GlobalStateManager instead of using the ``recipe`` argument;
        # confirm the two always agree before unifying them.
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
        elif recipe.float8_block_scaling():
            self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
        # elif for other recipes (mxfp8, etc.)

    def reset_parameters(self, defer_init=False):
        """Reset parameters through the base class, then tag them for tensor parallelism.

        When ``defer_init`` is true only the base-class reset runs. Otherwise
        every weight is marked tensor-parallel along dim 1 (row-parallel) or
        dim 0 (any other mode) with stride 1. When biases are used, row-parallel
        biases get a ``sequence_parallel`` attribute and column-parallel biases
        are marked tensor-parallel via ``set_tensor_model_parallel_attributes``.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        # Weights: row-parallel layers split the weight along dim 1, others along dim 0.
        shard_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=shard_dim,
                stride=1,
            )

        if not self.use_bias:
            return

        # Biases: attribute depends on the parallel mode; other modes are left untouched.
        for bias_name in self.bias_names:
            if self.parallel_mode == "row":
                getattr(self, bias_name).sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(getattr(self, bias_name), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : bool, default = False
                     When set and the module runs with quantized compute, an output
                     quantizer is supplied to the GEMM so the output may be quantized.
                     Forced on when the fprop userbuffer (``ub_overlap_rs_fprop``)
                     holds FP8 data.
        fp8_grad : bool, default = False
                   When set, gradients are enabled and the module runs with quantized
                   compute, a quantizer for the input gradient is supplied to the
                   backward pass. Forced on when the dgrad userbuffer
                   (``ub_overlap_rs_dgrad``) holds FP8 data.

        Returns
        -------
        out : torch.Tensor
              Output of the linear layer.
        bias : torch.Tensor
               Only returned when ``return_bias`` is set: the bias tensor cast
               to the activation dtype.
        """
        if is_in_onnx_export_mode():
            return self.onnx_forward(inp, fp8_output)

        debug = self.is_debug_iter()

        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
        # When a skip-FP8-weight-update tensor is in use, is_first_microbatch is overridden.
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        # Userbuffers holding FP8 data force quantized output / quantized grad.
        if self.ub_overlap_rs_fprop:
            if get_ub(
                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
            ).is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(
                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
            ).is_fp8_ubuf():
                fp8_grad = True

        # Run on the device of the module's first registered parameter.
        with torch.cuda.device(
            getattr(self, list(self.named_parameters())[0][0]).device
        ), self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

            quantizers = (
                self._get_quantizers(fp8_output, fp8_grad)
                if not debug
                else self._get_debug_quantizers(fp8_output, fp8_grad)
            )
            # Fall back to the regular quantizers when no debug feature is active.
            if debug:
                if self.no_debug_features_active(quantizers):
                    debug = False
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
            ) = quantizers

            # With autograd enabled go through _Linear.apply; otherwise call
            # _Linear.forward directly with None in the slot that apply would
            # fill with the autograd context.
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            # NOTE: these arguments are positional; their order must not change.
            args += (
                weight_tensor,
                inp,
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
                self.symmetric_ar_type,
                self.save_original_input,
                debug,
            )
            out = linear_fn(*args)
        # Row-parallel layers with bias add it here, outside the GEMM.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Collect the quantizers used by the forward and backward GEMMs.

        Returns a 6-tuple ``(input, weight, output, grad_input, grad_weight,
        grad_output)``. Every entry is ``None`` when quantized compute
        (``self.fp8``) is disabled. Otherwise:

        * the input and weight quantizers are always set and marked internal;
        * the output quantizer is set only when ``fp8_output`` is true;
        * the grad-output quantizer is set (and marked internal) only when
          autograd is enabled, and the grad-input quantizer additionally
          requires ``fp8_grad``;
        * the grad-weight quantizer is always ``None``.
        """
        if not self.fp8:
            # Same shape and type as the normal return below.
            return (None,) * 6

        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
        (weight_quantizer,) = self._get_weight_quantizers()

        output_quantizer = None
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]

        grad_input_quantizer = None
        grad_weight_quantizer = None
        grad_output_quantizer = None
        if torch.is_grad_enabled():
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        """Wrap the regular quantizers in ``DebugQuantizer`` objects.

        Each of the six quantizers returned by ``_get_quantizers`` (possibly
        ``None``) is wrapped together with this module's name, a tensor label
        ("activation", "weight", "output", "dgrad", "wgrad", "gradient") and
        the tensor-parallel group. Must only be called while
        ``TEDebugState.debug_enabled`` is set.
        """
        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        assert TEDebugState.debug_enabled
        # Imported lazily: only needed when debugging is enabled.
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
        return tuple(
            DebugQuantizer(self.name, name, q, self.tp_group)
            for name, q in zip(names, original_quantizers)
        )

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
        """Get the weight tensors of the module.

        Returns one tensor per name in ``self.weight_names``. If any weight is a
        ``QuantizedTensor``:

        * with quantized compute (``self.fp8``), only a single, unsplit weight
          is supported and ``RuntimeError`` is raised otherwise;
        * without quantized compute, a warning is issued and the quantized
          weights are dequantized. Plain tensors are returned unchanged.

        Raises
        ------
        RuntimeError
            If quantized compute is enabled and a ``QuantizedTensor`` weight is
            split across multiple parameters.
        """
        unfused_weights = [getattr(self, name) for name in self.weight_names]
        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
            if self.fp8:
                if len(unfused_weights) != 1:
                    raise RuntimeError(
                        "Splitting QuantizedTensor into multiple params is not supported"
                    )
            else:
                warnings.warn(
                    "You are using quantized weights without quantized compute. "
                    "Please make sure this is intentional."
                )
                # Only dequantize quantized entries: torch.Tensor.dequantize on a
                # plain tensor would convert it to float32.
                unfused_weights = [
                    w.dequantize() if isinstance(w, QuantizedTensor) else w
                    for w in unfused_weights
                ]
        return unfused_weights

    def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return the concatenated weight tensor and the (optional) bias tensor.

        The per-split weights come from ``_get_weight_tensors``, which already
        validates ``QuantizedTensor`` weights and dequantizes them when
        quantized compute is disabled. The splits are joined with
        ``noop_cat``. The bias splits are joined the same way when
        ``use_bias`` is set; otherwise the bias is ``None``.
        """
        # No re-validation here: _get_weight_tensors performs the identical checks.
        weight_tensor = noop_cat(self._get_weight_tensors())
        bias_tensor = (
            noop_cat([getattr(self, name) for name in self.bias_names])
            if self.use_bias
            else None
        )
        return weight_tensor, bias_tensor

    def onnx_forward(
        self,
        inp: torch.Tensor,
        fp8_output: bool,
    ) -> torch.Tensor:
        """
        Export-friendly forward pass used during ONNX export.

        Aims at numerical equivalence with ``forward`` while only using operations
        that have ONNX symbolic translations; intended for inference only. When
        quantizers are active, input and weight go through a quantize/dequantize
        round trip. A quantized output is not supported.
        """
        from ..export import onnx_gemm

        assert_warmed_up(self)
        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
        weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
        quantizers = self._get_quantizers(fp8_output, False)
        input_quantizer, weight_quantizer, output_quantizer = quantizers[:3]
        compute_dtype = inp.dtype

        # Quantize/dequantize round trip on the input, restored to its dtype.
        if input_quantizer is not None:
            inp = input_quantizer.onnx_dequantize(input_quantizer.onnx_quantize(inp)).to(
                compute_dtype
            )

        if weight_quantizer is not None:
            weight_tensor = weight_quantizer.onnx_dequantize(
                weight_quantizer.onnx_quantize(weight_tensor)
            )
        if bias_tensor is not None:
            bias_tensor = bias_tensor.to(compute_dtype)
        weight_tensor = weight_tensor.to(compute_dtype)

        output = onnx_gemm(weight_tensor, inp, bias_tensor if self.apply_bias else None)

        if output_quantizer is not None:
            raise NotImplementedError("ONNX export of quantized output is not supported")

        return (output, bias_tensor) if self.return_bias else output

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Apply Float8 current-scaling recipe options to this layer's quantizers.

        Forward (``fwd=True``): copies ``power_2_scale`` / ``amax_epsilon`` from
        ``recipe.fp8_quant_fwd_inp`` onto the GEMM input quantizer and from
        ``recipe.fp8_quant_fwd_weight`` onto the weight quantizer. For
        column-parallel layers with sequence parallelism, amax reduction over
        ``self.tp_group`` is enabled on the input quantizer.

        Backward (``fwd=False``): copies the same two options from
        ``recipe.fp8_quant_bwd_grad`` onto the grad-output quantizer. For
        row-parallel layers with sequence parallelism, amax reduction over
        ``self.tp_group`` is enabled on that quantizer.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"

        if not fwd:
            grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_cfg = recipe.fp8_quant_bwd_grad
            grad_output_q.force_pow_2_scales = grad_cfg.power_2_scale
            grad_output_q.amax_epsilon = grad_cfg.amax_epsilon
            # Sequence-parallel row layers: enable amax reduction over the TP group.
            if self.sequence_parallel and self.parallel_mode == "row":
                grad_output_q.with_amax_reduction = True
                grad_output_q.amax_reduction_group = self.tp_group
            return

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        weight_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]

        input_cfg = recipe.fp8_quant_fwd_inp
        input_q.force_pow_2_scales = input_cfg.power_2_scale
        input_q.amax_epsilon = input_cfg.amax_epsilon

        weight_cfg = recipe.fp8_quant_fwd_weight
        weight_q.force_pow_2_scales = weight_cfg.power_2_scale
        weight_q.amax_epsilon = weight_cfg.amax_epsilon

        # Sequence-parallel column layers: enable input amax reduction over the TP group.
        if self.sequence_parallel and self.parallel_mode == "column":
            input_q.with_amax_reduction = True
            input_q.amax_reduction_group = self.tp_group

    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on NVFP4 recipe + linear.

        Only sequence-parallel layers are affected:

        * forward, column-parallel: the GEMM input quantizer gets amax
          reduction enabled over ``self.tp_group``;
        * backward, row-parallel: the grad-output quantizer gets amax
          reduction enabled over ``self.tp_group``.
        """
        assert recipe.nvfp4(), "Incorrect recipe."
        if fwd:
            if self.sequence_parallel and self.parallel_mode == "column":
                # customize input_quantizer with amax reduction TP group
                input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
                input_quantizer.with_amax_reduction = True
                input_quantizer.amax_reduction_group = self.tp_group
        elif self.sequence_parallel and self.parallel_mode == "row":
            # customize grad_output_quantizer with amax reduction TP group
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.with_amax_reduction = True
            grad_output_quantizer.amax_reduction_group = self.tp_group

    def _get_weight_quantizers(self) -> List[Optional[Quantizer]]:
        """Get the weight quantizers of the module.

        Returns a one-element list: the GEMM weight quantizer, marked internal,
        when quantized compute (``self.fp8``) or calibration
        (``self.fp8_calibration``) is enabled; otherwise ``[None]``.
        """
        if not self.fp8 and not self.fp8_calibration:
            return [None]
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        return [weight_quantizer]

    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on blockwise scaling recipe + linear.

        Only sequence-parallel layers are affected:

        * forward, column-parallel: ``all_gather_usage`` is set on the GEMM
          input quantizer;
        * backward, row-parallel: ``all_gather_usage`` is set on the
          grad-output quantizer.
        """
        assert (
            recipe.float8_block_scaling()
        ), "blockwise scaling recipe quantizer customization here"

        if fwd:
            if self.sequence_parallel and self.parallel_mode == "column":
                # set compact for inp tensor X
                # NOTE(review): presumably selects a layout suited to the
                # sequence-parallel all-gather of the input — confirm.
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].all_gather_usage = True
        else:
            if self.sequence_parallel and self.parallel_mode == "row":
                # set compact for grad_output tensor dY
                # NOTE(review): presumably selects a layout suited to the
                # sequence-parallel all-gather of dY — confirm.
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].all_gather_usage = True