linear.py 70.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union, List
7
8
from functools import reduce
from operator import mul as multiply_op
9
import warnings
10
import os
11
12
13

import torch

14
import transformer_engine_torch as tex
15

16
from transformer_engine.common.recipe import Recipe
17
from transformer_engine.pytorch import torch_version
18

19
from .base import (
20
21
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
22
    get_ub,
23
    get_workspace,
24
25
26
27
28
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
29
from ._common import noop_cat, WeightGradStore
30
from ..fp8 import FP8GlobalStateManager
31
32
from ..utils import (
    cast_if_needed,
33
    clear_tensor_data,
34
    divide,
35
    init_method_constant,
36
37
    requires_grad,
    needs_quantized_gemm,
38
    assert_dim_for_fp8_exec,
39
40
    nvtx_range_pop,
    nvtx_range_push,
41
42
43
44
45
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
46
    symmetric_all_reduce,
47
48
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
49
    is_fp8_activation_recompute_enabled,
50
    in_fp8_activation_recompute_phase,
51
52
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
53
54
)
from ..cpp_extensions import (
55
    general_gemm,
56
)
57
from ..constants import GemmParallelModes, dist_group_type
58
from ..jit import no_torch_dynamo
59
from ..graph import is_graph_capturing
60
61
from ..tensor.quantized_tensor import (
    QuantizedTensor,
62
    QuantizedTensorBase,
63
64
65
66
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
67
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
68
from ..tensor.mxfp8_tensor import MXFP8Quantizer
69
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
70
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
71
from ..export import is_in_onnx_export_mode, assert_warmed_up
72
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
73
74
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
75

# Public names exported by ``from <this module> import *``.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        grad_weight_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        save_original_input: bool = False,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        """Forward pass of the linear layer: ``y = x * w^T (+ bias)``.

        The pass runs in these stages, each marked by a banner comment below:

        1. Prepare the input. In FP8/debug mode it is quantized with
           ``input_quantizer``; otherwise it is cast to ``activation_dtype``.
           For column-parallel + sequence-parallel it is then all-gathered
           over ``tp_group`` with NCCL, or through a Userbuffers all-gather
           when ``ub_overlap_ag_fprop`` is set.
        2. Prepare the weight. In FP8/debug mode a quantized workspace comes
           from ``module.get_weight_workspace``; otherwise the weight is cast
           to ``activation_dtype``.
        3. Run the forward GEMM via ``general_gemm``. It can be overlapped
           with a Userbuffers reduce-scatter when ``ub_overlap_rs_fprop`` is
           set.
        4. For row-parallel with ``tp_size > 1``, reduce the output. With
           sequence parallelism this is a reduce-scatter; with tensor
           parallelism it is a (symmetric or regular) all-reduce.
        5. If ``is_grad_enabled``, save tensors and configuration on ``ctx``
           for the backward pass.

        Returns:
            The output tensor. It is sharded along the first dimension when a
            reduce-scatter was performed.

        Raises:
            ValueError: FP8/debug mode is active, the input is not already
                quantized, and ``input_quantizer`` is None.
            AssertionError: ``inp.shape[-1]`` does not match the weight's
                input features, or DelayedScaling (``Float8Quantizer``) is
                combined with ``save_original_input``.
        """
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        assert inp.shape[-1] == in_features, "GEMM not possible"

        # Configure tensor-parallel communication
        tp_world_size = get_distributed_world_size(tp_group)
        backward_needs_input = is_grad_enabled and weight.requires_grad
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )

        # Configure Userbuffers communication (comm+GEMM overlap)
        ub_obj = None
        ub_type = None
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG

        # ------------------------------------------------------
        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp  # Input tensor to save for backward (maybe sharded)
        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
        # True only when this function created the quantized input itself
        # (so its unused usages may be discarded before saving for backward).
        own_quantized_input = False
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if save_original_input:
                assert not isinstance(
                    input_quantizer, Float8Quantizer
                ), "DelayedScaling recipe is not supported with save_original_input"

        if with_input_all_gather_nccl or ub_overlap_ag_fprop:  # All-gather input tensor

            # Cast local input tensor if needed
            if fp8 or debug:
                if input_quantizer is None:
                    raise ValueError("Missing quantizer for input tensor")
                if not isinstance(inputmat, QuantizedTensorBase):
                    input_quantizer.set_usage(
                        rowwise=True, columnwise=backward_needs_input and not save_original_input
                    )
                    if isinstance(
                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                    ):
                        # All-gather is not supported with FP8 column-wise data
                        input_quantizer.set_usage(columnwise=False)
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP

            # Initialize gathered input tensor
            quantizer = None
            if fp8 or debug:
                quantizer = input_quantizer
                # The forward GEMM only consumes row-wise data of the gathered input
                quantizer.set_usage(rowwise=True, columnwise=False)
            if with_input_all_gather_nccl:  # Perform NCCL all-gather
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
                    tp_group,
                    quantizer=quantizer,
                )
            elif ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_obj,
                    inputmat,
                    quantizer,
                    tp_group,
                )

        else:  # Do not all-gather input tensor
            if fp8 or debug:
                if isinstance(inputmat, QuantizedTensorBase):
                    inputmat.update_usage(rowwise_usage=True)
                else:
                    if input_quantizer is None:
                        raise ValueError("Missing quantizer for input tensor")
                    input_quantizer.set_usage(
                        rowwise=True, columnwise=backward_needs_input and not save_original_input
                    )
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
            inputmat_total = inputmat
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
        # ------------------------------------------------------
        # Input tensor is ready for GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare weight tensor
        # ------------------------------------------------------
        weightmat = weight
        if fp8 or debug:
            # Configure quantizer
            if weight_quantizer is not None:
                # Column-wise weight data is needed for dgrad, or when FP8
                # activation recompute will rerun this forward later.
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

            # Get quantized weight
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)

        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
        # ------------------------------------------------------
        # Weight tensor is ready for GEMM...
        # ------------------------------------------------------

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Choose whether to use GEMM kernel with split accumulator
        use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Output buffer for Userbuffers reduce-scatter
        reduce_scatter_out = None
        if ub_overlap_rs_fprop:
            out_shape = list(inp.shape)
            out_shape[0] //= tp_world_size
            out_shape[-1] = out_features
            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

        # ------------------------------------------------------
        # Forward GEMM
        # Note: y = x * w^T
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm")
        gemm_out, *_, reduce_scatter_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=activation_dtype,
            bias=bias,
            use_split_accumulator=use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=reduce_scatter_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")
        # ------------------------------------------------------
        # Finished forward GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare output tensor
        # Note: Perform tensor-parallel communication
        # ------------------------------------------------------
        out = None
        if ub_overlap_rs_fprop:
            out = reduce_scatter_out
        elif parallel_mode == "row" and tp_size > 1:
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            out = gemm_out
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
        else:
            out = gemm_out
        # ------------------------------------------------------
        # Output tensor is ready to return...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Cache state for backward pass
        # ------------------------------------------------------

        if is_grad_enabled:
            if save_original_input:
                inputmat = inp

            ctx.weight_quantizer = weight_quantizer
            saved_inputmat = None

            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            if backward_needs_input:
                if not save_original_input:
                    if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
                        # For sequence parallel in vanilla FP8, row-wise data is
                        # needed to gather the input. For MXFP8 and FP8 block-scaled
                        # tensors, column-wise-only data can be all-gathered, so the
                        # row-wise data is discarded before saving.
                        if (
                            isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
                            or not ctx.backward_input_needs_gather
                        ):
                            inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                saved_inputmat = inputmat

            # Weight with column-wise usage is needed for dgrad GEMM.
            if inp.requires_grad:
                if isinstance(weightmat, QuantizedTensorBase):
                    weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading and saved_inputmat is not None:
                mark_activation_offload(saved_inputmat)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            # The backward pass re-evaluates this same condition (CPU offloading or
            # a nonzero integer in NVTE_SWAP_OVERLAP_GRAD) before reading
            # ctx.grad_added_to_main_grad, so the parsing here must stay in sync.
            if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.input_quantizer = input_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                # This check is needed to ensure that main_grad is not created
                # during the forward pass when using MCore FSDP as it creates
                # the main_grad buffer lazily before backprop
                if hasattr(weight, "__fsdp_param__"):
                    # MCore FSDP creates main_grad lazily before backward
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            ctx.debug = debug
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False

            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # Querying is_first_fp8_module() may change IS_FIRST_FP8_MODULE;
                # restore it during activation recompute so the recomputed
                # forward does not alter the global state.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # ------------------------------------------------------
        # Cached state for backward pass is ready...
        # ------------------------------------------------------

        return out

    @staticmethod
459
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
460
        # pylint: disable=missing-function-docstring
461

462
463
464
465
466
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

467
        with torch.cuda.nvtx.range("_Linear_backward"):
468
469
470
471
            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
472

473
474
475
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
476
477
478

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
479
                ctx.main_grad_func()
480
481
482
483
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

evt_fugx1's avatar
evt_fugx1 committed
484
            if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
485
486
487
488
                if ctx.grad_added_to_main_grad:
                    weight = ctx.weight_object
                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                    weight.main_grad = main_grad
489

490
491
492
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
493
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
494
495
496
497
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
498
                weight_fp8,
499
            )
500
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
501

502
            # Configure Userbuffers communication (comm+GEMM overlap)
503
            ctx.ub_obj_gradout = None
504
            ub_obj_dgrad = None
505
            ub_obj_wgrad = None
506
507
            ub_type_dgrad = None
            ub_type_wgrad = None
508
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
509
            if ctx.ub_overlap_ag:
510
                # Overlap grad_output all-gather with dgrad compute
511
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
512
513
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
514
515
516
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
517
518
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
519
520
521
522
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
523
524
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
525
526
527
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
528
                    ub_type_wgrad = tex.CommOverlapType.RS
529
530
531
532
533
534
535
536

            # --------------------------------------------------
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
            # --------------------------------------------------

            # Unmodified grad output tensor
            grad_output_arg = grad_output
537

538
539
540
541
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
542
543
544
545
546
547
548
                quantizer = ctx.grad_output_quantizer
                quantizer.set_usage(rowwise=True, columnwise=True)
                if ctx.ub_overlap_ag:
                    # Userbuffers only supports communication for one
                    # tensor usage at a time. Configure quantizer with
                    # usage for only dgrad GEMM.
                    quantizer.set_usage(columnwise=False)
549

550
            # Prepare grad output tensor
551
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
552
553
554
555
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
556
557
558
559
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
560
            )
561
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
562

563
564
565
566
567
568
569
570
571
572
            # --------------------------------------------------
            # Grad output tensor is ready for computing grad input...
            # --------------------------------------------------

            # --------------------------------------------------
            # Prepare input tensor
            # Note: Input tensor is needed for wgrad GEMM.
            # Tensor-parallel communication is overlapped with dgrad
            # GEMM.
            # --------------------------------------------------
573
            inputmat_total = None
574
            inputmat_total_work = None
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
            if ctx.requires_wgrad:
                input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
                if ctx.fp8 or ctx.debug:
                    if not input_is_quantized:
                        quantizer = ctx.input_quantizer
                        if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                            quantizer.set_usage(
                                rowwise=True,
                                columnwise=not ctx.backward_input_needs_gather,
                            )
                        else:
                            quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat = quantizer(inputmat)
                else:
                    if input_is_quantized:
                        inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                    else:
                        inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
593
            if ctx.backward_input_needs_gather:
594
                quantizer = None
595
                if ctx.fp8 or ctx.debug:
596
                    quantizer = ctx.input_quantizer
597
598
599
600
601
602
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
                if ctx.ub_bulk_dgrad:
                    inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                        ub_obj_dgrad,
                        inputmat,
                        quantizer,
                        ctx.tp_group,
                    )
                else:
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                    inputmat_total, inputmat_total_work = gather_along_first_dim(
                        inputmat,
                        ctx.tp_group,
                        async_op=True,
                        quantizer=quantizer,
                    )
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
619
620
            else:
                inputmat_total = inputmat
621
622
623
            # --------------------------------------------------
            # Input tensor is ready for computing grad weight...
            # --------------------------------------------------
624

625
            # --------------------------------------------------
626
            # Compute grad input tensor
627
628
            # --------------------------------------------------

629
630
            dgrad = None
            dgrad_work = None
631
            if ctx.requires_dgrad:
632

633
634
635
636
637
638
639
640
                # Make sure required data is available
                if isinstance(grad_output, QuantizedTensorBase):
                    grad_output.update_usage(rowwise_usage=True)
                if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
                    weight_fp8.update_usage(columnwise_usage=True)

                # Choose whether to use GEMM kernel with split accumulator
                use_split_accumulator = _2X_ACC_DGRAD
641
642
643
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
644
                        use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
645

646
647
648
649
650
651
652
653
654
655
                # Update grad input quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # Output buffers for Userbuffers reduce-scatter
                gemm_out = None
                reduce_scatter_out = None
                if ctx.ub_overlap_rs_dgrad:
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
656
                    )
657
658
                elif ctx.ub_bulk_wgrad:
                    gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
659

660
661
662
663
                # dgrad GEMM
                # Note: dx = dy * w
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
                gemm_out, *_, reduce_scatter_out = general_gemm(
664
665
666
667
668
669
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
670
                    out=gemm_out,
671
                    out_dtype=ctx.activation_dtype,
672
                    use_split_accumulator=use_split_accumulator,
673
674
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
675
                    extra_output=reduce_scatter_out,
676
677
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
678
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
679

680
681
                # Prepare grad input tensor
                # Note: Perform tensor-parallel communication
682
                if ctx.ub_overlap_rs_dgrad:
683
684
685
686
                    dgrad = reduce_scatter_out
                elif ctx.ub_bulk_wgrad:
                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
                elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
687
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
688
                    dgrad = gemm_out
689
690
691
692
693
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
694
                        )
695
                    else:
696
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
697
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
698
699
700
701
702
703
704
705
706
707
                else:
                    dgrad = gemm_out

            # --------------------------------------------------
            # Grad input tensor has been computed...
            # --------------------------------------------------

            # --------------------------------------------------
            # Compute grad weight
            # --------------------------------------------------
708

709
710
            wgrad = None
            if ctx.requires_wgrad:
711

712
713
714
                # Prepare input tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
715
716
717
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
                if ctx.fp8 or ctx.debug:
                    if isinstance(inputmat_total, QuantizedTensorBase):
                        inputmat_total.update_usage(columnwise_usage=True)
                    else:
                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat_total = ctx.input_quantizer(inputmat_total)

                # Prepare grad output tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
                    # UB does not support overlapping grad output
                    # all-gather with wgrad GEMM. Also, we can't
                    # convert row-scaled MXFP8 to column-scaled, so we
                    # can't reuse the grad output that was gathered
733
734
                    # for the dgrad GEMM. We work around by explicitly
                    # overlapping the NCCL operation with the dgrad GEMM.
735
                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
736
737
738
739
740
741
742
743
744
745
746
747
748
749
                    # Get the communication stream from the dgrad GEMM and set it as the current torch stream
                    dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
                    with torch.cuda.stream(dgrad_comm_stream):
                        # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
                        # This ensures that we don't start until all communication for the dgrad GEMM is complete
                        grad_output, grad_output_work = gather_along_first_dim(
                            grad_output_arg,
                            ctx.tp_group,
                            async_op=True,
                            quantizer=ctx.grad_output_quantizer,
                        )
                    # Synchronize with the main stream
                    grad_output_work.wait()

750
751
752
753
754
755
                if ctx.fp8 or ctx.debug:
                    if isinstance(grad_output, QuantizedTensorBase):
                        grad_output.update_usage(columnwise_usage=True)
                    else:
                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                        grad_output = ctx.grad_output_quantizer(grad_output)
756

757
758
759
760
761
762
763
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

764
765
766
767
768
769
770
771
772
                # Figure out whether to output wgrad GEMM directly into main grad
                if ctx.is_first_microbatch is not None:
                    accumulate_wgrad_into_param_main_grad = (
                        ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                    )
                else:
                    accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

                # Output buffer for overlapping FP8 grad input
773
                # reduce-scatter with wgrad GEMM
774
                reduce_scatter_out = None
775
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
776
777
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
778
779
                    )

780
781
782
783
                # Arguments to include in wgrad GEMM closure
                wgrad_gemm_kwargs = {
                    "workspace": get_workspace(),
                    "out_dtype": (
784
785
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
                    "quantization_params": ctx.grad_weight_quantizer,
                    "accumulate": accumulate_wgrad_into_param_main_grad,
                    "layout": "NT",
                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
                    "use_split_accumulator": use_split_accumulator,
                    "grad": True,
                    "ub": ub_obj_wgrad,
                    "ub_type": ub_type_wgrad,
                    "extra_output": reduce_scatter_out,
                    "bulk_overlap": ctx.ub_bulk_wgrad,
                }

                def wgrad_gemm(
                    x: torch.Tensor,
                    dy: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                    """Perform wgrad GEMM: dw = dy^T * x

                    May be fused with bgrad computation.

                    May be called outside of this function to enable
                    some advanced communication/compute overlapping.

                    """
                    nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                    nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
                    return dw, db

                # Choose whether to call wgrad GEMM now or delay
817
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
818
819
820
821
822
823
824
825
826
827
828
                    if (
                        wgrad_gemm_kwargs["ub"] is not None
                        or wgrad_gemm_kwargs["ub_type"] is not None
                        or wgrad_gemm_kwargs["extra_output"] is not None
                        or wgrad_gemm_kwargs["bulk_overlap"]
                    ):
                        raise NotImplementedError(
                            "Delayed weight grad computation is not supported "
                            "with Userbuffers (tensor-parallel communication overlapping)"
                        )
                    ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
829
830
831
                    # overlap_grad_reduce, dongcl
                    if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
                        wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
832
833
                else:

834
835
836
837
                    # Call wgrad GEMM now
                    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                    # Update grad bias if needed
838
839
840
841
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

842
                    # Deallocate input tensor if permitted
843
844
                    if ctx.owns_input:
                        clear_tensor_data(inputmat_total)
845

846
                # Update grad input if overlapping reduce-scatter with wgrad GEMM
847
848
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
849
                        dgrad = reduce_scatter_out
850
                    else:
851
852
853
854
855
                        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

            # --------------------------------------------------
            # Grad weight has been computed...
            # --------------------------------------------------
856

857
            # Don't return grad bias if not needed
858
859
860
            if not ctx.use_bias:
                grad_bias = None

861
            # Make sure all tensor-parallel communication is finished
862
863
864
865
866
867
868
869
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
870
            # Handle custom DDP from mcore.
871
872
873
874
875
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
876
                weight.grad_added_to_main_grad = True
877
                if getattr(weight, "zero_out_wgrad", False):
878
879
880
881
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
882
                    )
883
                else:
884
885
886
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
887
                    )
888
889
890
891
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
892

893
        # Update FP8 scaling factors if needed
894
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
895
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
896
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
897
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
898

899
        # Scatter fp8 weight buffers
900
        if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
901
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
902
        return (
903
            wgrad,
904
905
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
906
907
908
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
909
            None,  # wgrad_store
910
911
912
913
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
914
915
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
916
917
918
919
920
921
922
923
924
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
925
926
927
928
929
930
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
931
            None,  # ub_name
932
            None,  # fp8_output
933
            None,  # fsdp_group
934
935
            None,  # module
            None,  # skip_fp8_weight_update
936
            None,  # symmetric_ar_type
937
            None,  # save_original_input
938
            None,  # debug
939
940
941
942
        )


class Linear(TransformerEngineBaseModule):
943
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = `None`
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    name : str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether or not to delay weight gradient computation. If set to `True`,
                         it's the user's responsibility to call `module.backward_dw` to compute
                         weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = `None`
                   Type of symmetric memory all-reduce to use during the forward pass.
                   This can help in latency bound communication situations.
                   Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                   is used.
    save_original_input : bool, default = `False`
                       If set to `True`, always saves the original input tensor rather than the
                       cast tensor. In some scenarios, the input tensor is used by multiple modules,
                       and saving the original input tensor may reduce the memory usage.
                       Cannot work with FP8 DelayedScaling recipe.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
        """Construct the layer and register its (optionally split) parameters.

        See the class docstring for the meaning of each argument. The ``ub_*``
        flags request Userbuffers comm+GEMM overlap. Each flag is only honored
        when the matching parallel mode ('column' or 'row') is active and
        sequence parallelism is enabled. ``ub_name`` must be given whenever any
        overlap ends up enabled.

        Raises:
            AssertionError: on an unsupported ``parallel_mode``, a split request
                on the 'meta' device, a missing ``ub_name`` with overlap
                enabled, or ``symmetric_ar_type`` with PyTorch < 2.7.
            ValueError: if ``parameters_split`` is empty or its sizes do not
                sum to ``out_features``.
            TypeError: if ``parameters_split`` has an unsupported form.
            RuntimeError: if a split size is not divisible by the TP size in
                column-parallel mode, or if FP8 params are requested together
                with a parameter split.
        """
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With return_bias=True the bias is handed back to the caller instead
        # of being added by this module.
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.symmetric_ar_type = symmetric_ar_type
        self.save_original_input = save_original_input
        self.name = name

        if TEDebugState.debug_enabled:
            # NOTE(review): this runs before the ub_* attributes below are
            # assigned, so any flag it resets may be overwritten further down.
            # Confirm against the base-class implementation.
            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        # Resolve the tensor-parallel world size. Without an explicit group the
        # caller-supplied tp_size is trusted. When tp_size > 1, the group must
        # be supplied later through set_tensor_parallel_group().
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each rank holds only its shard of the partitioned dimension.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        # Sequence parallelism is meaningless without more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        # Bulk dgrad/wgrad overlap is mutually exclusive with dgrad
        # reduce-scatter overlap.
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        if any(
            (
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            )
        ):
            # The previous message interpolated ub_name, which is always None
            # when this assertion fires, producing "layer 'None' ...".
            assert (
                ub_name is not None
            ), "ub_name must be provided when Userbuffers comm+GEMM overlap is enabled."
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params. Individual parameters are registered
        # below as slices (views) of these buffers.
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly. A non-divisible out_features is caught
            # by the sum check below.
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # Keep the bias attribute names resolvable with empty tensors.
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        self.gemm_bias_unfused_add = bool(self.parallel_mode == "row" and self.apply_bias)

1264
1265
1266
1267
1268
1269
1270
1271
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        After the base-class setup, quantizers are customized for the active
        recipe: current-scaling and block-scaling recipes get layer-specific
        adjustments, and other recipes (e.g. MXFP8) are left unchanged.

        Args:
            fwd: whether the forward (True) or backward (False) meta tensors
                are being initialized.
            recipe: recipe forwarded to the base-class implementation.
        """
        super().set_meta_tensor(fwd, recipe)

        # Customize quantizers based on each recipe & layer configs.
        # NOTE(review): the `recipe` argument is deliberately(?) replaced by the
        # globally active FP8 recipe here. Confirm the two cannot diverge.
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
        elif recipe.float8_block_scaling():
            self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        # elif for other recipes (mxfp8, etc.)

1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
    def reset_parameters(self, defer_init: bool = False):
        """Re-initialize parameters and attach tensor-parallel metadata.

        The base class performs the actual initialization. When initialization
        is not deferred, every weight is tagged as partitioned: along dim 1 for
        row-parallel layers, otherwise along dim 0. Biases are tagged as
        follows when present:

        * row-parallel: receive a ``sequence_parallel`` attribute;
        * column-parallel: are tagged as partitioned along dim 0.

        Args:
            defer_init: when True, skip the metadata tagging entirely.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        # Row-parallel layers shard the input-feature dimension of the weight.
        weight_partition_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=weight_partition_dim,
                stride=1,
            )

        if not self.use_bias:
            return

        for bias_name in self.bias_names:
            if self.parallel_mode == "row":
                getattr(self, bias_name).sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, bias_name),
                    is_parallel=True,
                    dim=0,
                    stride=1,
                )

1297
    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : Optional[bool], default = False
                     Whether the layer's FP8 output quantizer is used for the output.
                     Forced to `True` when `ub_overlap_rs_fprop` is enabled and the
                     fprop userbuffer is an FP8 buffer.
        fp8_grad : Optional[bool], default = False
                   Whether the layer's FP8 grad-input quantizer is used for the input
                   gradient. Forced to `True` when `ub_overlap_rs_dgrad` is enabled
                   and the dgrad userbuffer is an FP8 buffer.

        Returns
        -------
        out : torch.Tensor
             The layer output. When `return_bias` is set, the tuple `(out, bias)` is
             returned instead, with the bias cast to the activation dtype.
        """
        # ONNX export goes through a simplified, export-friendly path.
        if is_in_onnx_export_mode():
            return self.onnx_forward(inp, fp8_output)

        debug = TEDebugState.debug_enabled
        if debug:
            self._validate_name()

        # During FP8 graph capturing, weight-update skipping is driven by a tensor from
        # the global FP8 state, and `is_first_microbatch` is overridden to False.
        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        # If a comm-overlap userbuffer is FP8, force the matching tensor to be quantized.
        if self.ub_overlap_rs_fprop:
            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True

        with self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:
            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

            quantizers = (
                self._get_quantizers(fp8_output, fp8_grad)
                if not debug
                else self._get_debug_quantizers(fp8_output, fp8_grad)
            )
            if debug:
                if not any_feature_enabled(quantizers):
                    # If no feature is used, then run faster implementation with debug = False.
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
                    debug = False

                # NOTE(review): this check also runs when debug was just switched off above.
                if isinstance(weight_tensor, QuantizedTensor):
                    raise RuntimeError("FP8 weights are not supported in debug mode.")

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
            ) = quantizers

            # With autograd enabled, go through `apply` so backward is recorded; otherwise
            # call `forward` directly with `None` in the first slot (presumably `ctx`).
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                inp,
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
                self.symmetric_ar_type,
                self.save_original_input,
                debug,
            )
            out = linear_fn(*args)
        # When the bias was not passed into the GEMM above, add it here.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Return the quantizers used by the linear GEMMs.

        The result has six entries in the order ``(input, weight, output,
        grad_input, grad_weight, grad_output)``.

        When FP8 is disabled, a list of six ``None`` values is returned. Otherwise a
        tuple is returned in which:

        * the input quantizer is always set and marked ``internal``;
        * the weight quantizer comes from ``_get_weight_quantizers``;
        * the output quantizer is set only when ``fp8_output`` is true;
        * the grad-output quantizer is set and marked ``internal`` only when autograd
          is enabled;
        * the grad-input quantizer additionally requires ``fp8_grad``;
        * the grad-weight quantizer is always ``None``.
        """
        if not self.fp8:
            return [None] * 6
        grad_input_quantizer = None
        grad_weight_quantizer = None
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
        (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        # Backward quantizers are only needed when a backward pass can happen.
        if torch.is_grad_enabled():
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        """Wrap each quantizer from ``_get_quantizers`` in a ``DebugQuantizer``.

        Returns a 6-tuple in the same order as ``_get_quantizers``. Each wrapper is
        built with this layer's name, a tensor role label, the original quantizer
        (possibly ``None``) and the TP group. Must only be called when debug mode is
        enabled.
        """
        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        assert TEDebugState.debug_enabled
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        # Labels line up positionally with (input, weight, output, grad_input,
        # grad_weight, grad_output).
        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
        return tuple(
            DebugQuantizer(self.name, name, q, self.tp_group)
            for name, q in zip(names, original_quantizers)
        )

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
        """Get the weight tensors of the module.

        Plain weights are returned as-is. If any weight is a ``QuantizedTensor``:

        * with FP8 enabled, exactly one weight parameter is required
          (``RuntimeError`` otherwise) and the weights are returned unchanged;
        * with FP8 disabled, a warning is emitted and every weight is dequantized.
        """
        weights = [getattr(self, weight_name) for weight_name in self.weight_names]
        if not any(isinstance(w, QuantizedTensor) for w in weights):
            return weights

        if not self.fp8:
            # Quantized storage without quantized compute: fall back to dequantized copies.
            warnings.warn(
                "You are using quantized weights without quantized compute. "
                "Please make sure this is intentional."
            )
            return [w.dequantize() for w in weights]

        if len(weights) != 1:
            raise RuntimeError("Splitting QuantizedTensor into multiple params is not supported")
        return weights

    def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return the module's weight and bias, each combined into one tensor via ``noop_cat``.

        Quantized-weight validation is performed by ``_get_weight_tensors``. That
        includes the single-parameter requirement under FP8 and dequantization with
        a warning when FP8 is disabled.

        Returns
        -------
        (weight_tensor, bias_tensor)
            ``bias_tensor`` is ``None`` when ``use_bias`` is False.
        """
        # `_get_weight_tensors` already validated/dequantized any QuantizedTensor
        # weights, so repeating that check here would be dead code.
        unfused_weights = self._get_weight_tensors()
        weight_tensor = noop_cat(unfused_weights)

        bias_tensor = None
        if self.use_bias:
            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])

        return weight_tensor, bias_tensor

    def onnx_forward(
        self,
        inp: torch.Tensor,
        fp8_output: bool,
    ) -> torch.Tensor:
        """
        ONNX-compatible version of the forward function that provides numerical equivalence
        while only using operations that have defined ONNX symbolic translations.
        This simplified implementation is designed specifically for inference scenarios.

        Quantization is emulated with quantize/dequantize round trips on the input
        and the weight. Weight and bias are cast to the input dtype before the GEMM.
        Quantized output is not supported (``NotImplementedError``).

        Returns ``output``, or ``(output, bias)`` when ``return_bias`` is set.
        """
        from ..export import onnx_gemm

        assert_warmed_up(self)
        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
        weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
        input_quantizer, weight_quantizer, output_quantizer = self._get_quantizers(
            fp8_output, False
        )[:3]
        compute_dtype = inp.dtype

        if input_quantizer is not None:
            inp = input_quantizer.onnx_dequantize(input_quantizer.onnx_quantize(inp)).to(
                compute_dtype
            )

        if weight_quantizer is not None:
            weight_tensor = weight_quantizer.onnx_dequantize(
                weight_quantizer.onnx_quantize(weight_tensor)
            )
        if bias_tensor is not None:
            bias_tensor = bias_tensor.to(compute_dtype)
        weight_tensor = weight_tensor.to(compute_dtype)

        output = onnx_gemm(weight_tensor, inp, bias_tensor if self.apply_bias else None)

        if output_quantizer is not None:
            raise NotImplementedError("ONNX export of quantized output is not supported")

        return (output, bias_tensor) if self.return_bias else output

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + linear.

        Forward (``fwd=True``):
            Copy ``power_2_scale`` and ``amax_epsilon`` from the recipe onto the
            input and weight quantizers. For column-parallel layers with sequence
            parallelism, also enable amax reduction of the input over the TP group.

        Backward (``fwd=False``):
            Copy the same settings onto the grad-output quantizer. For row-parallel
            layers with sequence parallelism, enable its amax reduction over the TP
            group.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"
        if fwd:
            fwd_quantizers = self.quantizers["scaling_fwd"]
            input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
            weight_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]
            input_cfg = recipe.fp8_quant_fwd_inp
            weight_cfg = recipe.fp8_quant_fwd_weight

            input_quantizer.force_pow_2_scales = input_cfg.power_2_scale
            input_quantizer.amax_epsilon = input_cfg.amax_epsilon
            weight_quantizer.force_pow_2_scales = weight_cfg.power_2_scale
            weight_quantizer.amax_epsilon = weight_cfg.amax_epsilon

            # Parallel-related: reduce the input amax across the TP group.
            if self.sequence_parallel and self.parallel_mode == "column":
                input_quantizer.with_amax_reduction = True
                input_quantizer.amax_reduction_group = self.tp_group
            return

        grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        grad_cfg = recipe.fp8_quant_bwd_grad
        grad_output_quantizer.force_pow_2_scales = grad_cfg.power_2_scale
        grad_output_quantizer.amax_epsilon = grad_cfg.amax_epsilon

        # Parallel-related: reduce the grad-output amax across the TP group.
        if self.sequence_parallel and self.parallel_mode == "row":
            grad_output_quantizer.with_amax_reduction = True
            grad_output_quantizer.amax_reduction_group = self.tp_group

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module.

        Returns a one-element list. The element is ``None`` when FP8 is disabled.
        Otherwise it is the forward GEMM1 weight quantizer, marked ``internal``.
        """
        if self.fp8:
            quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
            quantizer.internal = True
            return [quantizer]
        return [None]

    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on blockwise scaling recipe + linear."""
        assert (
            recipe.float8_block_scaling()
        ), "blockwise scaling recipe quantizer customization here"

        if fwd:
            if self.sequence_parallel and self.parallel_mode == "column":
                # set compact for inp tensor X
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].all_gather_usage = True
        else:
            if self.sequence_parallel and self.parallel_mode == "row":
                # set compact for grad_output tensor dY
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].all_gather_usage = True