linear.py 75.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union, List
7
8
from functools import reduce
from operator import mul as multiply_op
9
import warnings
10
import os
11
12
13

import torch

14
import transformer_engine_torch as tex
15

16
from transformer_engine.common.recipe import Recipe
17
from transformer_engine.pytorch import torch_version
18

19
from .base import (
20
21
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
22
    get_ub,
23
    get_workspace,
24
25
26
27
28
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
29
from ._common import noop_cat, WeightGradStore
30
from ..quantization import FP8GlobalStateManager
31
32
from ..utils import (
    cast_if_needed,
33
    clear_tensor_data,
34
    divide,
35
    init_method_constant,
36
37
    requires_grad,
    needs_quantized_gemm,
38
    assert_dim_for_fp8_exec,
39
    assert_dim_for_all_gather,
40
41
    nvtx_range_pop,
    nvtx_range_push,
42
43
44
45
46
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
47
    symmetric_all_reduce,
48
49
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
50
    is_fp8_activation_recompute_enabled,
51
    in_fp8_activation_recompute_phase,
52
53
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
54
55
)
from ..cpp_extensions import (
56
    general_gemm,
57
)
58
from ..constants import GemmParallelModes, dist_group_type
59
from ..jit import no_torch_dynamo
60
from ..graph import is_graph_capturing
61
from ..quantized_tensor import (
62
    QuantizedTensor,
63
    QuantizedTensorStorage,
64
65
66
67
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
68
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
69
from ..tensor.mxfp8_tensor import MXFP8Quantizer
70
from ..tensor.utils import is_custom
71
from ..export import is_in_onnx_export_mode, assert_warmed_up
72
73
74
75
76
77
from ..cpu_offload import (
    is_cpu_offload_enabled,
    start_offload,
    mark_not_offload,
    mark_activation_offload,
)
78
from ...debug.pytorch.debug_state import TEDebugState
79

80
81
82
83
84
85
86
87
88
89
90
# Public names exported by this module. Only ``Linear`` is listed; the
# underscore-prefixed ``_Linear`` autograd function defined below is internal.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        grad_weight_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        save_original_input: bool = False,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run the forward pass of the Linear layer.

        Computes ``out = inp @ weight^T (+ bias)`` with ``general_gemm`` and,
        when ``is_grad_enabled`` is set, caches on ``ctx`` everything the
        backward pass reads. Depending on the arguments it also:

        * quantizes the input and weight (``fp8`` / ``debug``) or casts them
          to ``activation_dtype`` (AMP);
        * all-gathers the input along the first dimension for column-parallel
          sequence parallelism, either with NCCL or with Userbuffers
          (``ub_overlap_ag_fprop``);
        * for row parallelism with ``tp_size > 1``, reduce-scatters
          (``sequence_parallel``) or all-reduces (``tensor_parallel``) the
          output, unless the reduce-scatter is overlapped with the GEMM via
          Userbuffers (``ub_overlap_rs_fprop``).

        All Userbuffers overlap flags are disabled when ``debug`` is set, so
        the corresponding NCCL code paths are used instead.

        Returns:
            The output tensor of the layer.

        Raises:
            AssertionError: if ``inp.shape[-1]`` differs from the weight's
                ``in_features``.
            ValueError: if quantization is requested (``fp8`` or ``debug``)
                but ``input_quantizer`` is None.
        """

        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        assert inp.shape[-1] == in_features, "GEMM not possible"

        # Turn off Userbuffers (comm+GEMM overlap) in debug mode.
        # This must happen *before* the tensor-parallel configuration below:
        # with_input_all_gather_nccl depends on ub_overlap_ag_fprop, and if it
        # were computed first, disabling the UB all-gather would leave the
        # input with no all-gather at all.
        if debug:
            ub_overlap_rs_fprop = False
            ub_overlap_ag_fprop = False
            ub_overlap_ag_dgrad = False
            ub_overlap_rs_dgrad = False
            ub_bulk_wgrad = False
            ub_bulk_dgrad = False

        # Configure tensor-parallel communication
        tp_world_size = get_distributed_world_size(tp_group)
        backward_needs_input = is_grad_enabled and weight.requires_grad
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )

        # Configure Userbuffers communication (comm+GEMM overlap)
        ub_obj = None
        ub_type = None
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop", fp8)
            ub_type = tex.CommOverlapType.AG

        # Custom recipe check
        custom = is_custom(input_quantizer) or is_custom(weight_quantizer)

        # ------------------------------------------------------
        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp  # Input tensor to save for backward (maybe sharded)
        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
        own_quantized_input = False
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
            if save_original_input:
                assert not isinstance(
                    input_quantizer, Float8Quantizer
                ), "DelayedScaling recipe is not supported with save_original_input"

        if with_input_all_gather_nccl or ub_overlap_ag_fprop:  # All-gather input tensor

            # Cast local input tensor if needed
            if fp8 or debug:
                if input_quantizer is None:
                    raise ValueError("Missing quantizer for input tensor")
                if not isinstance(inputmat, QuantizedTensorStorage) and not custom:
                    own_quantized_input = True
                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                    if isinstance(
                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                    ):
                        # All-gather is not supported with FP8 column-wise data
                        input_quantizer.set_usage(columnwise=False)
                    if save_original_input:
                        # No need for column-wise data since this
                        # tensor will not be cached for backward pass
                        input_quantizer.set_usage(columnwise=False)
                        own_quantized_input = False
                    inputmat = input_quantizer(inputmat)
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP

            # Initialize gathered input tensor (the GEMM only needs row-wise data)
            quantizer = None
            if fp8 or debug:
                quantizer = input_quantizer
                quantizer.set_usage(rowwise=True, columnwise=False)
            if with_input_all_gather_nccl:  # Perform NCCL all-gather
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
                    tp_group,
                    quantizer=quantizer,
                )
            elif ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_obj,
                    inputmat,
                    quantizer,
                    tp_group,
                )

        else:  # Do not all-gather input tensor
            if fp8 or debug:
                if isinstance(inputmat, QuantizedTensorStorage):
                    inputmat.update_usage(rowwise_usage=True)
                else:
                    if input_quantizer is None:
                        raise ValueError("Missing quantizer for input tensor")
                    input_quantizer.set_usage(
                        rowwise=True, columnwise=backward_needs_input and not save_original_input
                    )
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
            inputmat_total = inputmat

        if is_cpu_offload_enabled():
            start_offload(inputmat)
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
        # ------------------------------------------------------
        # Input tensor is ready for GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare weight tensor
        # ------------------------------------------------------
        weightmat = weight
        if fp8 or debug:
            # Configure quantizer
            # No need to set the quantizer states if weight is already quantized
            if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
                # Column-wise weight data is needed for dgrad, and also when
                # activations will be recomputed with FP8 later on.
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
            elif isinstance(weight, QuantizedTensor):
                # If weight is already quantized, no need to set quantizer states
                weight_quantizer = weight._quantizer
            # Get quantized weight (cached under "weight" when microbatching)
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)

        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
        # ------------------------------------------------------
        # Weight tensor is ready for GEMM...
        # ------------------------------------------------------

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Choose whether to use GEMM kernel with split accumulator
        use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Output buffer for Userbuffers reduce-scatter: first dim is the local
        # shard (divided by the TP world size), last dim is out_features.
        reduce_scatter_out = None
        if ub_overlap_rs_fprop:
            out_shape = list(inp.shape)
            out_shape[0] //= tp_world_size
            out_shape[-1] = out_features
            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

        # ------------------------------------------------------
        # Forward GEMM
        # Note: y = x * w^T
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm")
        gemm_out, *_, reduce_scatter_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=activation_dtype,
            bias=bias,
            use_split_accumulator=use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=reduce_scatter_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")
        # ------------------------------------------------------
        # Finished forward GEMM...
        # ------------------------------------------------------

        # Deallocate GEMM input tensor if no longer needed
        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
        # deallocated by GC. Manually deallocating is a temporary hack.
        if with_input_all_gather_nccl:
            clear_tensor_data(inputmat_total)
            inputmat_total = None

        # ------------------------------------------------------
        # Prepare output tensor
        # Note: Perform tensor-parallel communication
        # ------------------------------------------------------
        out = None
        if ub_overlap_rs_fprop:
            out = reduce_scatter_out
        elif parallel_mode == "row" and tp_size > 1:
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            out = gemm_out
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
        else:
            out = gemm_out
        # ------------------------------------------------------
        # Output tensor is ready to return...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Cache state for backward pass
        # ------------------------------------------------------
        if is_grad_enabled:
            if save_original_input:
                inputmat = inp

            ctx.weight_quantizer = weight_quantizer

            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            # Discard unneeded data in input tensor
            if (
                backward_needs_input
                and own_quantized_input
                and isinstance(inputmat, QuantizedTensorStorage)
            ):
                # The tensor all-gathered in backward is the *input*, so the
                # input quantizer decides whether column-wise all-gather is
                # supported (backward makes the same decision with
                # ctx.input_quantizer). input_quantizer is non-None here because
                # own_quantized_input is only set after quantizing with it.
                if (
                    ctx.backward_input_needs_gather
                    and input_quantizer.supports_only_rowwise_all_gather()
                ):
                    # All-gather is not supported with FP8 column-wise data
                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
                else:
                    # Discard row-wise data since it is not needed in backward pass
                    inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)

            # Cached input tensor
            saved_inputmat = None
            if backward_needs_input:
                saved_inputmat = inputmat

            if cpu_offloading and saved_inputmat is not None:
                mark_activation_offload(saved_inputmat)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            # Backward evaluates the same condition (including the
            # NVTE_SWAP_OVERLAP_GRAD environment variable) and then reads
            # ctx.grad_added_to_main_grad, so the two must stay in sync.
            if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            mark_not_offload(weight, weightmat, bias)
            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.input_quantizer = input_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                # This check is needed to ensure that main_grad is not created
                # during the forward pass when using MCore FSDP as it creates
                # the main_grad buffer lazily before backprop
                if hasattr(weight, "__fsdp_param__"):
                    # MCore FSDP creates main_grad lazily before backward
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            ctx.debug = debug
            ctx.custom = custom
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            # Backward reads the dgrad all-gather overlap flag as ctx.ub_overlap_ag
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False

            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # NOTE(review): is_first_fp8_module() appears to consume the global
                # IS_FIRST_FP8_MODULE flag; it is restored during activation
                # recompute so the recompute pass does not alter it — confirm.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # ------------------------------------------------------
        # Cached state for backward pass is ready...
        # ------------------------------------------------------

        return out
491
492

    @staticmethod
493
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
494
        # pylint: disable=missing-function-docstring
495

496
497
498
499
500
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

501
        with torch.cuda.nvtx.range("_Linear_backward"):
502
503
504
505
            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
506

507
508
509
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
510
511
512

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
513
                ctx.main_grad_func()
514
515
516
517
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

dongchl's avatar
dongchl committed
518
519
            if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
                if ctx.grad_added_to_main_grad:
520
                    weight = ctx.weight_object
dongchl's avatar
dongchl committed
521

522
523
            if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                weight.main_grad = main_grad
524

525
526
527
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
528
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
529
530
531
532
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
533
                weight_fp8,
534
            )
535
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
536

537
            # Configure Userbuffers communication (comm+GEMM overlap)
538
            ctx.ub_obj_gradout = None
539
            ub_obj_dgrad = None
540
            ub_obj_wgrad = None
541
542
            ub_type_dgrad = None
            ub_type_wgrad = None
543
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
544
            if ctx.ub_overlap_ag:
545
                # Overlap grad_output all-gather with dgrad compute
546
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
547
548
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
549
550
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
551
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
552
553
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
554
555
556
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
557
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
558
559
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
560
561
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
562
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
563
                    ub_type_wgrad = tex.CommOverlapType.RS
564
565
566
567
568
569
570
571

            # --------------------------------------------------
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
            # --------------------------------------------------

            # Unmodified grad output tensor
            grad_output_arg = grad_output
572

573
574
575
576
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
577
578
579
580
581
582
583
                quantizer = ctx.grad_output_quantizer
                quantizer.set_usage(rowwise=True, columnwise=True)
                if ctx.ub_overlap_ag:
                    # Userbuffers only supports communication for one
                    # tensor usage at a time. Configure quantizer with
                    # usage for only dgrad GEMM.
                    quantizer.set_usage(columnwise=False)
584

585
586
587
588
589
590
591
592
593
594
595
596
597
            # Adjust the quantization direction approach depending
            # on whether wgrad calculations will be performed.
            # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization
            #       results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!`
            # NOTE: For `ctx.bias is True`, selected quantize kernel errors with
            #       `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.`
            if (
                not ctx.use_bias
                and not ctx.requires_wgrad
                and ctx.grad_output_quantizer is not None
            ):
                ctx.grad_output_quantizer.set_usage(columnwise=False)

598
            # Prepare grad output tensor
599
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
600
601
602
603
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
604
605
606
607
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
608
            )
609
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
610

611
612
613
614
615
616
617
618
619
620
            # --------------------------------------------------
            # Grad output tensor is ready for computing grad input...
            # --------------------------------------------------

            # --------------------------------------------------
            # Prepare input tensor
            # Note: Input tensor is needed for wgrad GEMM.
            # Tensor-parallel communication is overlapped with dgrad
            # GEMM.
            # --------------------------------------------------
621
            inputmat_total = None
622
            inputmat_total_work = None
623
624
            if ctx.requires_wgrad:
                if ctx.fp8 or ctx.debug:
625
                    if isinstance(inputmat, QuantizedTensorStorage):
626
627
                        # Input tensor is already quantized
                        pass
628
                    elif ctx.debug or ctx.custom:
629
630
631
632
                        # Debug quantizer will be applied immediately before wgrad GEMM
                        pass
                    else:
                        # Quantize input tensor
633
                        quantizer = ctx.input_quantizer
634
                        if quantizer.supports_only_rowwise_all_gather():
635
                            # All-gather is not supported with FP8 column-wise data
636
637
638
639
640
641
642
643
                            quantizer.set_usage(
                                rowwise=True,
                                columnwise=not ctx.backward_input_needs_gather,
                            )
                        else:
                            quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat = quantizer(inputmat)
                else:
644
                    if isinstance(inputmat, QuantizedTensorStorage):
645
646
647
                        inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                    else:
                        inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
648
            if ctx.backward_input_needs_gather:
649
                quantizer = None
650
                if ctx.fp8 or ctx.debug:
651
                    quantizer = ctx.input_quantizer
652
                    if quantizer.supports_only_rowwise_all_gather():
653
654
655
656
657
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
                if ctx.ub_bulk_dgrad:
                    inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                        ub_obj_dgrad,
                        inputmat,
                        quantizer,
                        ctx.tp_group,
                    )
                else:
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                    inputmat_total, inputmat_total_work = gather_along_first_dim(
                        inputmat,
                        ctx.tp_group,
                        async_op=True,
                        quantizer=quantizer,
                    )
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
674
675
            else:
                inputmat_total = inputmat
676
677
678
            # --------------------------------------------------
            # Input tensor is ready for computing grad weight...
            # --------------------------------------------------
679

680
            # --------------------------------------------------
681
            # Compute grad input tensor
682
683
            # --------------------------------------------------

684
685
            dgrad = None
            dgrad_work = None
686
            if ctx.requires_dgrad:
687

688
                # Make sure required data is available
689
                if isinstance(grad_output, QuantizedTensorStorage):
690
                    grad_output.update_usage(rowwise_usage=True)
691
692
693
                if ctx.weight_quantizer is not None and isinstance(
                    weight_fp8, QuantizedTensorStorage
                ):
694
695
696
697
                    weight_fp8.update_usage(columnwise_usage=True)

                # Choose whether to use GEMM kernel with split accumulator
                use_split_accumulator = _2X_ACC_DGRAD
698
699
700
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
701
                        use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
702

703
704
705
706
707
708
709
710
711
712
                # Update grad input quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # Output buffers for Userbuffers reduce-scatter
                gemm_out = None
                reduce_scatter_out = None
                if ctx.ub_overlap_rs_dgrad:
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
713
                    )
714
715
                elif ctx.ub_bulk_wgrad:
                    gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
716

717
718
                # dgrad GEMM
                # Note: dx = dy * w
719

720
721
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
                gemm_out, *_, reduce_scatter_out = general_gemm(
722
723
724
725
726
727
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
728
                    out=gemm_out,
729
                    out_dtype=ctx.activation_dtype,
730
                    use_split_accumulator=use_split_accumulator,
731
732
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
733
                    extra_output=reduce_scatter_out,
734
735
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
736
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
737

738
739
                # Prepare grad input tensor
                # Note: Perform tensor-parallel communication
740
                if ctx.ub_overlap_rs_dgrad:
741
742
743
744
                    dgrad = reduce_scatter_out
                elif ctx.ub_bulk_wgrad:
                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
                elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
745
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
746
                    dgrad = gemm_out
747
748
749
750
751
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
752
                        )
753
                    else:
754
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
755
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
756
757
758
759
760
761
762
763
764
765
                else:
                    dgrad = gemm_out

            # --------------------------------------------------
            # Grad input tensor has been computed...
            # --------------------------------------------------

            # --------------------------------------------------
            # Compute grad weight
            # --------------------------------------------------
766

767
768
            wgrad = None
            if ctx.requires_wgrad:
769

770
771
772
                # Prepare input tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
773
774
775
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
776
                if ctx.fp8 or ctx.debug:
777
                    if isinstance(inputmat_total, QuantizedTensorStorage):
778
779
780
781
782
783
784
785
786
                        inputmat_total.update_usage(columnwise_usage=True)
                    else:
                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat_total = ctx.input_quantizer(inputmat_total)

                # Prepare grad output tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
787
                    # UB does not support pipelined overlapping of the grad output
                    # all-gather with the wgrad GEMM. Also, row-scaled MXFP8 cannot be
                    # converted to column-scaled, so the grad output that was gathered
                    # for the dgrad GEMM cannot be reused. We work around this by
                    # explicitly overlapping the AG operation with the dgrad GEMM.

                    # Get the communication stream from the dgrad GEMM to use for the AG
                    dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

                    # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
798
                    ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
799

800
                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
801
802
803
804
805
806
807

                    # We use the send stream to copy into the userbuffers.
                    # This is the same stream that we will use to access the data in the AG,
                    # so we don't need to add any syncs yet.
                    with torch.cuda.stream(dgrad_send_stream):
                        grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                            ub_obj_overlap_wgrad,
808
                            grad_output_arg,
809
                            ctx.grad_output_quantizer,
810
811
                            ctx.tp_group,
                        )
812
813
814
815
816

                    # All-gather the grad output on the dgrad GEMM's communication streams so
                    # that the all-gather overlaps with the dgrad GEMM
                    tex.bulk_overlap_ag_with_external_gemm(
                        ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
                    )
817

818
                if ctx.fp8 or ctx.debug:
819
                    if isinstance(grad_output, QuantizedTensorStorage):
820
821
822
823
                        grad_output.update_usage(columnwise_usage=True)
                    else:
                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                        grad_output = ctx.grad_output_quantizer(grad_output)
824

825
826
827
828
829
830
831
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

832
833
834
835
836
837
838
839
840
                # Figure out whether to output wgrad GEMM directly into main grad
                if ctx.is_first_microbatch is not None:
                    accumulate_wgrad_into_param_main_grad = (
                        ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                    )
                else:
                    accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

                # Output buffer for overlapping FP8 grad input
841
                # reduce-scatter with wgrad GEMM
842
                reduce_scatter_out = None
843
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
844
845
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
846
847
                    )

848
849
850
851
                # Arguments to include in wgrad GEMM closure
                wgrad_gemm_kwargs = {
                    "workspace": get_workspace(),
                    "out_dtype": (
852
853
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
854
                    "quantization_params": ctx.grad_weight_quantizer,
855
856
857
858
859
                    "accumulate": (
                        accumulate_wgrad_into_param_main_grad
                        if not getattr(weight, "overwrite_main_grad", False)
                        else False
                    ),
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
                    "layout": "NT",
                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
                    "use_split_accumulator": use_split_accumulator,
                    "grad": True,
                    "ub": ub_obj_wgrad,
                    "ub_type": ub_type_wgrad,
                    "extra_output": reduce_scatter_out,
                    "bulk_overlap": ctx.ub_bulk_wgrad,
                }

                def wgrad_gemm(
                    x: torch.Tensor,
                    dy: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                    """Perform wgrad GEMM: dw = dy^T * x

                    May be fused with bgrad computation.

                    May be called outside of this function to enable
                    some advanced communication/compute overlapping.

                    """
                    nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                    nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
                    return dw, db

                # Choose whether to call wgrad GEMM now or delay
889
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
890
891
892
893
894
895
896
897
898
899
900
                    if (
                        wgrad_gemm_kwargs["ub"] is not None
                        or wgrad_gemm_kwargs["ub_type"] is not None
                        or wgrad_gemm_kwargs["extra_output"] is not None
                        or wgrad_gemm_kwargs["bulk_overlap"]
                    ):
                        raise NotImplementedError(
                            "Delayed weight grad computation is not supported "
                            "with Userbuffers (tensor-parallel communication overlapping)"
                        )
                    ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
901
902
903
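                    # With NVTE_OVERLAP_GRAD_REDUCE set, return an uninitialized wgrad tensor of
                    # the weight's shape, since the real wgrad GEMM was deferred via
                    # ctx.wgrad_store above. NOTE(review): presumably so that grad-reduce
                    # overlap logic still receives a gradient tensor; confirm with the caller.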
                    if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
                        wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
904
905
                else:

906
907
908
909
                    # Call wgrad GEMM now
                    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                    # Update grad bias if needed
910
911
912
913
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

914
                    # Deallocate tensors if permitted
915
                    if ctx.owns_input:
916
                        # Input tensor is internal
917
                        clear_tensor_data(inputmat_total)
918
919
                    elif ctx.backward_input_needs_gather:
                        # Gathered input tensor is internal
920
                        clear_tensor_data(inputmat_total)
921
922
923
                    if ctx.parallel_mode == "row" and ctx.sequence_parallel:
                        # Gathered grad output tensor is internal
                        clear_tensor_data(grad_output)
924

925
                # Update grad input if overlapping reduce-scatter with wgrad GEMM
926
927
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
928
                        dgrad = reduce_scatter_out
929
                    else:
930
931
932
933
934
                        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

            # --------------------------------------------------
            # Grad weight has been computed...
            # --------------------------------------------------
935

936
            # Don't return grad bias if not needed
937
938
939
            if not ctx.use_bias:
                grad_bias = None

940
            # Make sure all tensor-parallel communication is finished
941
942
943
944
945
946
947
948
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
949
            # Handle custom DDP from mcore.
950
951
952
953
954
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
955
                weight.grad_added_to_main_grad = True
956
                if getattr(weight, "zero_out_wgrad", False):
957
958
959
960
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
961
                    )
962
                else:
963
964
965
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
966
                    )
967
968
969
970
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
971

972
        # Update FP8 scaling factors if needed
973
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
974
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
975
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
976
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
977

978
        # Scatter fp8 weight buffers
979
        if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
980
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
981
        return (
982
            wgrad,
983
984
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
985
986
987
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
988
            None,  # wgrad_store
989
990
991
992
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
993
994
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
995
996
997
998
999
1000
1001
1002
1003
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
1004
1005
1006
1007
1008
1009
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
1010
            None,  # ub_name
1011
            None,  # fp8_output
1012
            None,  # fsdp_group
1013
1014
            None,  # module
            None,  # skip_fp8_weight_update
1015
            None,  # symmetric_ar_type
1016
            None,  # save_original_input
1017
            None,  # debug
1018
1019
1020
1021
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 argument passed to `get_rng_state_tracker` to get a specific RNG tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters, and
                      `out_features` must be divisible by the number of names. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0; the sizes must sum to `out_features`.
                      The resulting parameters will have names that end in `_weight` or
                      `_bias`, so trailing underscores are stripped from any provided names.
                      With `parallel_mode='column'`, each split size must also be divisible by
                      the tensor-parallel size. Not supported when `device="meta"`.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    name : str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when the
                       tensor-parallel size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives. Ignored when `tp_group` is provided; the world size of
             `tp_group` is used instead.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   With `'column'`, `out_features` is divided across the tensor-parallel
                   processes; with `'row'`, `in_features` is. When set to `None`, no
                   communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in. If, in addition, the weight
                             tensor has an attribute `overwrite_main_grad` set to `True`,
                             `main_grad` is overwritten instead of accumulated into.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether or not to delay weight gradient computation. If set to `True`,
                         it's the user's responsibility to call `module.backward_dw` to compute
                         weight gradients. Not supported together with Userbuffers
                         tensor-parallel communication overlap of the wgrad GEMM.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                   Type of symmetric memory all-reduce to use during the forward pass.
                   This can help in latency-bound communication situations.
                   Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                   is used.
    save_original_input : bool, default = `False`
                       If set to `True`, always saves the original input tensor rather than the
                       cast tensor. In some scenarios, the input tensor is used by multiple modules,
                       and saving the original input tensor may reduce the memory usage.
                       Cannot work with FP8 DelayedScaling recipe.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
1117
        rng_tracker_name: Optional[str] = None,
1118
1119
1120
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
1121
        params_dtype: Optional[torch.dtype] = None,
1122
        parallel_mode: Optional[str] = None,
cyanguwa's avatar
cyanguwa committed
1123
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
1124
        device: Union[torch.device, str] = "cuda",
1125
        ub_overlap_ag: bool = False,
1126
        ub_overlap_rs: bool = False,
1127
        ub_overlap_rs_dgrad: bool = False,
1128
1129
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
1130
        ub_name: Optional[str] = None,
1131
        delay_wgrad_compute: bool = False,
1132
        symmetric_ar_type: Optional[str] = None,
1133
        save_original_input: bool = False,
1134
        name: Optional[str] = None,
1135
1136
    ) -> None:
        super().__init__()
1137
1138

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
1139
1140
1141
1142
1143
1144
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
1145
1146
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
1147
        self.symmetric_ar_type = symmetric_ar_type
1148
        self.save_original_input = save_original_input
1149
1150
        self.name = name

1151
1152
        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

1153
1154
        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

1176
        # Column parallel TP overlap options
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )
1195
1196

        # Row parallel TP overlap options
1197
1198
1199
1200
1201
1202
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
        self.ub_name = ub_name

1217
1218
1219
1220
1221
1222
1223
        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

1224
1225
1226
        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

1227
1228
1229
1230
1231
1232
1233
1234
        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
1235
        if self.use_bias:
1236
1237
1238
1239
1240
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )
1241

1242
1243
1244
1245
        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
1246
        if parameters_split is None:
1247
1248
1249
1250
1251
1252
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
cyanguwa's avatar
cyanguwa committed
1253
        elif isinstance(parameters_split, dict):
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
cyanguwa's avatar
cyanguwa committed
1266
        else:
1267
            raise TypeError("Invalid configuration for parameters split")
1268

1269
1270
1271
1272
1273
1274
        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )
1275

1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

1286
1287
1288
1289
1290
        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
1291
1292
1293
1294
1295
1296
1297
1298
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
1299
            if is_subview and with_fp8_params:
1300
1301
1302
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
1303

1304
            # Construct weight parameter
1305
1306
1307
1308
1309
1310
1311
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )
1312

1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)
cyanguwa's avatar
cyanguwa committed
1329

1330
        if with_fp8_params:
1331
1332
            self.init_fp8_metadata()

1333
        self.reset_parameters(defer_init=device == "meta")
1334

1335
1336
1337
1338
1339
1340
1341
        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

1342
1343
1344
1345
1346
        if self.wgrad_store.delay_wgrad_compute():
            for name, param in self.named_parameters():
                if name in self.weight_names or name in self.bias_names:
                    param.skip_backward_post_hook = True

1347
1348
1349
1350
1351
1352
1353
1354
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd, then apply Linear-specific tweaks.

        The base class (re)creates the quantizers for ``recipe``; this override then
        customizes them according to the recipe type and this layer's parallel
        configuration.

        Parameters
        ----------
        fwd : bool
            ``True`` to customize the forward (``"scaling_fwd"``) quantizers,
            ``False`` for the backward (``"scaling_bwd"``) quantizers.
        recipe : Recipe
            Recipe the quantizers are initialized for. If ``None``, the recipe from
            ``FP8GlobalStateManager.get_fp8_recipe()`` is used for the
            customization step.
        """
        super().set_meta_tensor(fwd, recipe)

        # Customize with the same recipe the base class just built the quantizers
        # for. Re-reading the global recipe unconditionally could apply settings for
        # a different recipe type than the one these quantizers belong to.
        if recipe is None:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
        elif recipe.float8_block_scaling():
            self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
        # elif for other recipes (mxfp8, etc.)

    def reset_parameters(self, defer_init=False):
        """Reset parameters and attach tensor-parallel metadata to them.

        Initialization itself is delegated to the base class. Unless
        ``defer_init`` is set, every weight is then tagged as tensor-parallel
        (partition dim 1 for row-parallel layers, dim 0 otherwise). When the
        module has biases, row-parallel biases record the ``sequence_parallel``
        flag and column-parallel biases are tagged as partitioned along dim 0.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        row_parallel = self.parallel_mode == "row"
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=1 if row_parallel else 0,
                stride=1,
            )

        if not self.use_bias:
            return
        for bias_name in self.bias_names:
            if row_parallel:
                # Row-parallel biases are not tagged as partitioned; they only
                # carry the sequence-parallel flag.
                getattr(self, bias_name).sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, bias_name),
                    is_parallel=True,
                    dim=0,
                    stride=1,
                )

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : Optional[bool], default = False
                    Whether to request a quantized output. When FP8 is enabled for
                    this module, this passes the module's GEMM1 output quantizer to
                    the linear op. It is forced to ``True`` when the forward
                    reduce-scatter userbuffer (``ub_overlap_rs_fprop``) is an FP8 buffer.
        fp8_grad : Optional[bool], default = False
                  Whether to request a quantized gradient w.r.t. the input. When FP8
                  is enabled and grad mode is on, this passes the module's GRAD_INPUT1
                  quantizer to the linear op. It is forced to ``True`` when the dgrad
                  reduce-scatter userbuffer (``ub_overlap_rs_dgrad``) is an FP8 buffer.

        Returns
        -------
        out : torch.Tensor
             Result of the linear transformation (including the bias, when the
             module applies it).
        bias : torch.Tensor
             Returned as a second element only when ``return_bias`` is set: the
             bias tensor cast to the module's activation dtype.
        """
        # ONNX export takes a separate path built from ONNX-exportable ops only.
        if is_in_onnx_export_mode():
            return self.onnx_forward(inp, fp8_output)

        debug = self.is_debug_iter()

        # While FP8GlobalStateManager reports graph capturing, it supplies a
        # skip-weight-update tensor; in that case the Python-level
        # is_first_microbatch hint is overridden.
        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        # If the userbuffers used for the fprop / dgrad reduce-scatter overlap are
        # FP8 buffers, the output / input-gradient must be produced in FP8.
        if self.ub_overlap_rs_fprop:
            if get_ub(
                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
            ).is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(
                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
            ).is_fp8_ubuf():
                fp8_grad = True

        # Run on the device of the module's first registered parameter.
        with torch.cuda.device(
            getattr(self, list(self.named_parameters())[0][0]).device
        ), self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

            # Debug quantizers are used only on debug iterations, and are replaced
            # by the regular ones again if no debug feature is active.
            quantizers = (
                self._get_quantizers(fp8_output, fp8_grad)
                if not debug
                else self._get_debug_quantizers(fp8_output, fp8_grad)
            )
            if debug:
                if self.no_debug_features_active(quantizers):
                    debug = False
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
            ) = quantizers

            # Without autograd, _Linear.forward is called directly; the leading
            # None takes the place of the autograd context argument.
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            # NOTE: positional arguments -- the order must match _Linear.forward exactly.
            args += (
                weight_tensor,
                inp,
                # The bias is passed to the op unless it must be added after the
                # row-parallel TP collective (see gemm_bias_unfused_add).
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
                self.symmetric_ar_type,
                self.save_original_input,
                debug,
            )
            out = linear_fn(*args)
        # Row-parallel with bias: the bias is added here, after the linear op
        # (and its TP collectives) has produced the output.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Collect the quantizers used by one forward/backward pass of this module.

        The six entries are ordered (input, weight, output, grad_input,
        grad_weight, grad_output). When FP8 is disabled every entry is ``None``.
        Otherwise the input and weight quantizers are always present. The output
        quantizer is present only when ``fp8_output`` is set. The grad-output
        quantizer is present only when grad mode is enabled, and the grad-input
        quantizer additionally requires ``fp8_grad``. No grad-weight quantizer is
        provided.
        """
        if not self.fp8:
            return [None] * 6

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
        (weight_quantizer,) = self._get_weight_quantizers()
        output_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_OUTPUT] if fp8_output else None

        grad_input_quantizer = None
        grad_output_quantizer = None
        if torch.is_grad_enabled():
            bwd_quantizers = self.quantizers["scaling_bwd"]
            grad_output_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            None,  # grad_weight: no quantizer is provided for the weight gradient
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        """Wrap every quantizer from ``_get_quantizers`` in a ``DebugQuantizer``.

        Each wrapper receives this module's name, a per-slot tensor name and the
        tensor-parallel group. The result is a tuple in the same slot order as
        ``_get_quantizers`` (input, weight, output, grad-input, grad-weight,
        grad-output).
        """
        base_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        assert TEDebugState.debug_enabled
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        # Debug-tool tensor names, aligned with the slots of base_quantizers.
        slot_names = ("activation", "weight", "output", "dgrad", "wgrad", "gradient")
        wrapped = [
            DebugQuantizer(self.name, slot_name, quantizer, self.tp_group)
            for slot_name, quantizer in zip(slot_names, base_quantizers)
        ]
        return tuple(wrapped)

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
        """Get the weight tensors of the module.

        Returns the per-split weight parameters named in ``self.weight_names``.

        * If no weight is a ``QuantizedTensor``, the weights are returned unchanged.
        * With FP8 compute enabled (``self.fp8``), quantized weights are returned
          as-is, but only a single (unsplit) weight is supported.
        * Without FP8 compute, a warning is emitted and every ``QuantizedTensor``
          is dequantized. Non-quantized entries are passed through untouched.

        Raises
        ------
        RuntimeError
            If FP8 compute is enabled and quantized weights are split across
            multiple parameters.
        """
        weights = [getattr(self, name) for name in self.weight_names]
        if not any(isinstance(w, QuantizedTensor) for w in weights):
            return weights

        if self.fp8:
            if len(weights) != 1:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
            return weights

        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        # Only dequantize the quantized entries: torch's Tensor.dequantize() on a
        # non-quantized tensor is not a no-op (it returns a float32 copy), which
        # would silently change the dtype of any plain high-precision split.
        return [w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weights]

    def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get the weight and bias as single tensors built from their splits.

        Weight validation and dequantization are handled by
        ``_get_weight_tensors``. The weight splits and (if present) the bias splits
        are then each joined with ``noop_cat``.

        Returns
        -------
        weight_tensor : torch.Tensor
            The weight splits joined with ``noop_cat``.
        bias_tensor : Optional[torch.Tensor]
            The bias splits joined with ``noop_cat``, or ``None`` when the module
            has no bias.
        """
        # _get_weight_tensors already rejects split QuantizedTensors under FP8 and
        # dequantizes them otherwise, so repeating those checks here is dead code.
        weight_tensor = noop_cat(self._get_weight_tensors())
        if self.use_bias:
            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
        else:
            bias_tensor = None

        return weight_tensor, bias_tensor

    def onnx_forward(
        self,
        inp: torch.Tensor,
        fp8_output: bool,
    ) -> torch.Tensor:
        """
        ONNX-compatible version of the forward function that provides numerical equivalence
        while only using operations that have defined ONNX symbolic translations.
        This simplified implementation is designed specifically for inference scenarios.

        Quantization of the input and the weight is expressed as an ONNX
        quantize/dequantize round trip. Quantized outputs are not supported. The
        result is the output, or ``(output, bias)`` when ``return_bias`` is set.
        """
        from ..export import onnx_gemm

        assert_warmed_up(self)
        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
        weight, bias = self._get_weight_and_bias_tensors()
        input_quantizer, weight_quantizer, output_quantizer = self._get_quantizers(
            fp8_output, False
        )[:3]
        compute_dtype = inp.dtype

        if input_quantizer is not None:
            # Quantize/dequantize round trip, then restore the original input dtype.
            inp = input_quantizer.onnx_dequantize(input_quantizer.onnx_quantize(inp))
            inp = inp.to(compute_dtype)

        if weight_quantizer is not None:
            weight = weight_quantizer.onnx_dequantize(weight_quantizer.onnx_quantize(weight))
        if bias is not None:
            bias = bias.to(compute_dtype)
        weight = weight.to(compute_dtype)

        # The bias only enters the GEMM when this module applies it.
        output = onnx_gemm(weight, inp, bias if self.apply_bias else None)

        if output_quantizer is not None:
            raise NotImplementedError("ONNX export of quantized output is not supported")

        if self.return_bias:
            return output, bias
        return output

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + linear.

        Forward (``fwd=True``):
            * copy ``power_2_scale`` / ``amax_epsilon`` from
              ``recipe.fp8_quant_fwd_inp`` onto the input quantizer;
            * copy the same settings from ``recipe.fp8_quant_fwd_weight`` onto the
              weight quantizer;
            * for sequence-parallel column-parallel layers, enable amax reduction
              of the input quantizer over the tensor-parallel group.

        Backward (``fwd=False``):
            * copy ``recipe.fp8_quant_bwd_grad`` settings onto the grad-output
              quantizer;
            * for sequence-parallel row-parallel layers, enable amax reduction of
              the grad-output quantizer over the tensor-parallel group.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"

        if not fwd:
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
            grad_output_quantizer.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
            # parallel related: reduce the dY amax across the tensor-parallel group
            if self.sequence_parallel and self.parallel_mode == "row":
                grad_output_quantizer.with_amax_reduction = True
                grad_output_quantizer.amax_reduction_group = self.tp_group
            return

        fwd_quantizers = self.quantizers["scaling_fwd"]

        input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
        input_quantizer.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon

        # The weight quantizer takes its own amax_epsilon & power_2_scale settings.
        weight_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
        weight_quantizer.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon

        # parallel related: reduce the input amax across the tensor-parallel group
        if self.sequence_parallel and self.parallel_mode == "column":
            input_quantizer.with_amax_reduction = True
            input_quantizer.amax_reduction_group = self.tp_group

    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on NVFP4 recipe + linear.

        For sequence-parallel layers, enable amax reduction over the
        tensor-parallel group on the input quantizer (forward pass,
        column-parallel layers) or on the grad-output quantizer (backward pass,
        row-parallel layers). Any other configuration is left untouched.
        """
        assert recipe.nvfp4(), "Incorrect recipe."
        reduction_mode = "column" if fwd else "row"
        if not (self.sequence_parallel and self.parallel_mode == reduction_mode):
            return

        if fwd:
            quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        else:
            quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        quantizer.with_amax_reduction = True
        quantizer.amax_reduction_group = self.tp_group

    def _get_weight_quantizers(self) -> List[Optional[Quantizer]]:
        """Get the weight quantizers of the module.

        Returns
        -------
        List[Optional[Quantizer]]
            A one-element list. When FP8 compute or FP8 calibration is enabled it
            holds the GEMM1_WEIGHT quantizer (with ``internal`` set to ``True``);
            otherwise it holds ``None``.
        """
        if not (self.fp8 or self.fp8_calibration):
            return [None]
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        return [weight_quantizer]

    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on blockwise scaling recipe + linear.

        Only sequence-parallel layers are affected. In the forward pass
        (``fwd=True``) of a column-parallel layer, the input quantizer gets
        ``all_gather_usage = True``. In the backward pass of a row-parallel layer,
        the grad-output quantizer gets ``all_gather_usage = True``.
        """
        assert (
            recipe.float8_block_scaling()
        ), "blockwise scaling recipe quantizer customization here"

        if fwd:
            if self.sequence_parallel and self.parallel_mode == "column":
                # set compact for inp tensor X
                # NOTE(review): presumably X is all-gathered across the TP group in
                # this configuration, hence the all-gather-friendly ("compact")
                # layout -- confirm against _Linear.forward.
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].all_gather_usage = True
        else:
            if self.sequence_parallel and self.parallel_mode == "row":
                # set compact for grad_output tensor dY
                # NOTE(review): presumably dY is all-gathered across the TP group
                # here -- confirm against _Linear.backward.
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].all_gather_usage = True