# NOTE: repository web-view residue captured with this file (not module code):
#   "...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on
#   "77c37d49f3f2354fb9edb13ae72fa01dd39ae4b5" -- linear.py 57.2 KB
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union
7
8
from functools import reduce
from operator import mul as multiply_op
9
10
11

import torch

12
import transformer_engine_torch as tex
13

14
from transformer_engine.common.recipe import Recipe
15
16
17
18
from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
19
    get_dummy_wgrad,
20
21
22
23
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
24
25
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ..fp8 import FP8GlobalStateManager
26
27
from ..utils import (
    cast_if_needed,
28
    clear_tensor_data,
29
    divide,
30
    init_method_constant,
31
32
    requires_grad,
    needs_quantized_gemm,
33
    non_tn_fp8_gemm_supported,
34
    assert_dim_for_fp8_exec,
35
36
    nvtx_range_pop,
    nvtx_range_push,
37
38
39
40
41
42
43
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
44
    is_fp8_activation_recompute_enabled,
45
    in_fp8_activation_recompute_phase,
46
47
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
48
49
)
from ..cpp_extensions import (
50
    general_gemm,
51
)
52
from ..constants import GemmParallelModes, dist_group_type
53
from ..jit import no_torch_dynamo
54
from ..graph import is_graph_capturing
55
56
57
58
59
60
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
61
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
62
from ..tensor.mxfp8_tensor import MXFP8Quantizer
63
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
64
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
65
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
66
67
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
# Names exported by ``from <this module> import *``; everything else here
# (including the ``_Linear`` autograd function) is an implementation detail.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        grad_weight_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        """Forward pass of the linear layer.

        Steps performed:

        * In FP8 or debug mode, quantize the input and weight.
        * Perform input tensor-parallel communication: an all-gather for
          column-parallel + sequence-parallel.
        * Run the GEMM through ``general_gemm``, optionally overlapped with
          Userbuffers communication.
        * Perform output tensor-parallel communication for row-parallel mode:
          reduce-scatter or all-reduce.
        * When ``is_grad_enabled``, store on ``ctx`` everything the backward
          pass reads.

        Returns:
            The GEMM output viewed as ``(-1, *inp.shape[1:-1], out_features)``.
            The leading dimension may be smaller than the input's after a
            sequence-parallel reduce-scatter.

        Raises:
            AssertionError: if ``inp.shape[-1]`` does not equal the weight's
                ``in_features``.
            NotImplementedError: if Userbuffers comm+GEMM overlap is requested
                under FP8 with a recipe that is not per-tensor scaling.
            ValueError: if FP8 or debug mode is active but ``input_quantizer``
                is ``None``.

        In FP8 mode, ``assert_dim_for_fp8_exec`` may also reject
        unsupported shapes.
        """

        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"

        tp_world_size = get_distributed_world_size(tp_group)
        backward_needs_input = is_grad_enabled and weight.requires_grad

        # Collapse all leading dimensions: (tokens, in_features)
        inputmat = inp.view(-1, in_features)

        # Validate the configuration *before* opening the input NVTX range, so
        # that raising here can never leave a push without its matching pop.
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
            ):
                raise NotImplementedError(
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                    " current scaling"
                )
        if (fp8 or debug) and input_quantizer is None:
            raise ValueError("Missing quantizer for input tensor")

        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat_total = None
        # Column-parallel + sequence-parallel needs the full input; gather it
        # with NCCL unless Userbuffers AG overlap handles the gather instead.
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )
        own_quantized_input = False
        # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
        force_hp_input_gather = (
            fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.
        if fp8 or debug:
            if with_input_all_gather_nccl:
                if force_hp_input_gather:
                    # Gather in high precision; the quantizer is handed to the
                    # gather helper, which is responsible for quantization.
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat, tp_group, quantizer=input_quantizer
                    )
                else:
                    if not isinstance(inputmat, QuantizedTensor):
                        # MXFP8 keeps column-wise data for the backward pass.
                        columnwise_usage = backward_needs_input and isinstance(
                            input_quantizer, MXFP8Quantizer
                        )
                        # force_hp_input_gather should enforce this
                        assert not isinstance(input_quantizer, Float8BlockQuantizer)
                        input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
                        inputmat = input_quantizer(inputmat)
                        own_quantized_input = True
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat,
                        tp_group,
                        quantizer=input_quantizer,
                    )
            else:
                if (
                    FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
                    and ub_bulk_dgrad
                ):
                    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                else:
                    input_quantizer.set_usage(
                        rowwise=True,
                        columnwise=backward_needs_input,
                    )
                if not isinstance(inputmat, QuantizedTensor):
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
                elif backward_needs_input:
                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
                inputmat_total = inputmat
        else:
            inputmat = cast_if_needed(inp, activation_dtype)
            if with_input_all_gather_nccl:
                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
            else:
                inputmat_total = inputmat
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

        # Cast weight to expected dtype
        weightmat = weight
        if fp8 or debug:
            # Configure quantizer
            if weight_quantizer is not None:
                # Column-wise weight data is needed for dgrad, or when FP8
                # activation recompute is on and this is not the recompute phase.
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
            # FP8 cast to workspace buffer
            # The cache is only refreshed on the first microbatch, or on every
            # call when microbatching is not in use.
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)

        # Cast bias to expected dtype
        # Quantized GEMMs with FP32 activations get a BF16 bias (presumably an
        # FP32 bias is unsupported by the quantized GEMM path -- confirm).
        bias_dtype = activation_dtype
        if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Optional Userbuffers comm+GEMM overlap setup
        ub_obj = None
        ub_type = None
        rs_out = None
        out_dtype = activation_dtype
        if ub_overlap_rs_fprop:
            # Reduce-scatter overlap: the GEMM writes its scattered result to rs_out.
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)

        elif ub_overlap_ag_fprop:
            # All-gather overlap: stage the local chunk in the communication
            # buffer and run the GEMM on the buffer.
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG
            if fp8:
                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
            ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
            inputmat_total = ub_obj.get_buffer(input_quantizer)

        nvtx_range_push(f"{nvtx_label}.gemm")
        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        out, *_, rs_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=out_dtype,
            bias=bias,
            use_split_accumulator=fprop_gemm_use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=rs_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")

        if is_grad_enabled:
            ctx.weight_quantizer = weight_quantizer
            saved_inputmat = None

            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            if backward_needs_input:
                if own_quantized_input and isinstance(inputmat, QuantizedTensor):
                    # For sequence parallel in vanilla FP8, rowwise data is
                    # needed to gather the input. For MXFP8, columnwise-only
                    # data can be allgathered.
                    if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                if force_hp_input_gather:
                    assert not isinstance(inputmat, QuantizedTensor)
                saved_inputmat = inputmat

            # Weight with column-wise usage is needed for dgrad GEMM.
            if inp.requires_grad:
                if isinstance(weightmat, QuantizedTensor):
                    weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading and saved_inputmat is not None:
                mark_activation_offload(saved_inputmat)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            if cpu_offloading:
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.force_hp_input_gather = force_hp_input_gather
            ctx.input_quantizer = input_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                ctx.main_grad = weight.main_grad

            ctx.debug = debug
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp_shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            # Stored as ``ub_overlap_ag``: the dgrad-side all-gather overlap flag.
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False
            # The saved input is "owned" (safe to free in backward) only if it
            # is a different tensor object than the caller's input.
            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                # During activation recompute, restore the flag read above so the
                # recomputed forward leaves it as it was (presumably
                # is_first_fp8_module() may modify it -- confirm).
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

        # Row Parallel Linear
        if ub_overlap_rs_fprop:
            out = rs_out
        elif parallel_mode == "row":
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

        out = out.view(-1, *inp_shape[1:-1], out_features)
        return out

    @staticmethod
397
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
398
        # pylint: disable=missing-function-docstring
399

400
401
402
403
404
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

405
        with torch.cuda.nvtx.range("_Linear_backward"):
406
407
408
409
410
411
412
413
414
415
            if (
                ctx.fp8
                and any(
                    [
                        ctx.ub_overlap_ag,
                        ctx.ub_overlap_rs_dgrad,
                        ctx.ub_bulk_dgrad,
                        ctx.ub_bulk_wgrad,
                    ]
                )
416
                and (ctx.fp8_recipe is not None)
417
            ):
418
                if not ctx.fp8_recipe.float8_per_tensor_scaling():
419
                    raise NotImplementedError(
420
421
                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                        " current scaling"
422
                    )
423
424
425
426
427

            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
428
429
430
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
431
432
433
434
435
436
437
438

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
                ctx.main_grad
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

439
440
441
442
443
            if ctx.cpu_offloading:
                if ctx.grad_added_to_main_grad:
                    weight = ctx.weight_object
                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                    weight.main_grad = main_grad
444

445
446
447
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
448
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
449
450
451
452
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
453
                weight_fp8,
454
            )
455
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
456

457
            ctx.ub_obj_gradout = None
458
            ub_obj_dgrad = None
459
            ub_obj_wgrad = None
460
461
            ub_type_dgrad = None
            ub_type_wgrad = None
462
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
463
464
            rs_out = None
            dgrad_bulk = None
465
            if ctx.ub_overlap_ag:
466
                # Overlap grad_output all-gather with dgrad compute
467
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
468
469
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
470
471
472
473

            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
474
475
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
476
477
478
479
480
481
482
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                )

            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
483
484
485
486
                    # NOTE: Copying into communication buffer will always prefer rowwise data,
                    #       and will copy columnwise data if rowwise does not exist. In that case,
                    #       the all-gather will apply to the leading dimension of the transpose,
                    #       which then needs to be interleaved correctly before WGRAD.
487
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
488
489
490
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
                    ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
491
492
493
494

                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
495
496
497
498
                    ub_type_wgrad = tex.CommOverlapType.RS
                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
                rowwise_usage = True
                columnwise_usage = True
                if ctx.ub_overlap_ag and isinstance(
                    ctx.grad_output_quantizer,
                    (Float8Quantizer, Float8CurrentScalingQuantizer),
                ):
                    # If data is in FP8 and communication is handled
                    # with Userbuffers, we compute FP8 transposes
                    # manually
                    columnwise_usage = False
                ctx.grad_output_quantizer.set_usage(
                    rowwise=rowwise_usage,
                    columnwise=columnwise_usage,
                )

518
519
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
520
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
521
522
523
524
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
525
526
527
528
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
529
            )
530
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
531

532
            # Launch tensor-parallel communication for input tensor
533
            inputmat_total = None
534
            inputmat_total_work = None
535
            if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
536
                quantizer = None
537
                if ctx.fp8 or ctx.debug:
538
                    quantizer = ctx.input_quantizer
539
540
541
542
543
544
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
545
                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
546
                gather_quantizer = None if ctx.force_hp_input_gather else quantizer
547
548
549
550
                inputmat_total, inputmat_total_work = gather_along_first_dim(
                    inputmat,
                    ctx.tp_group,
                    async_op=True,
551
                    quantizer=gather_quantizer,
552
                )
553
                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
554
555
556
            else:
                inputmat_total = inputmat

557
            # Check whether to output wgrad GEMM directly into main grad
558
559
560
561
562
563
564
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

565
566
567
            # Compute grad input tensor
            dgrad = None
            dgrad_work = None
568
            if ctx.requires_dgrad:
569
570
571
572
573

                # Update quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
                # dgrad GEMM
574
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
575
576
577
578
579
580
581
582
                dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
                        dgrad_gemm_use_split_accumulator = (
                            recipe.fp8_gemm_dgrad.use_split_accumulator
                        )

583
584
585
586
587
588
                if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
                    weight_fp8.update_usage(
                        rowwise_usage=ctx.weight_quantizer.rowwise_usage,
                        columnwise_usage=ctx.weight_quantizer.columnwise_usage,
                    )

589
590
591
592
593
594
595
596
597
                dgrad, *_, rs_out = general_gemm(
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
                    out=dgrad_bulk,
                    out_dtype=ctx.activation_dtype,
598
                    use_split_accumulator=dgrad_gemm_use_split_accumulator,
599
600
601
602
603
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
604
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
605
606
607
608
609

                # Launch tensor-parallel communication
                if ctx.ub_overlap_rs_dgrad:
                    dgrad = rs_out
                elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
610
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
611
612
613
614
615
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
616
                        )
617
                    else:
618
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
619
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
620

621
622
623
            # Compute grad weight tensor
            wgrad = None
            if ctx.requires_wgrad:
624
625

                # Synchronize tensor-parallel communication for input tensor
626
627
628
629
630
631
632
633
634
635
636
637
638
                if ctx.ub_bulk_dgrad:
                    inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                    if ctx.fp8:
                        if inputmat._data is None:
                            # All-gather executed on columnwise data and result is in rowwise data,
                            # so we need to fix the interleaving before WGRAD.
                            inputmat_total = _fix_gathered_fp8_transpose(
                                inputmat_total, ctx.tp_size
                            )
                        elif not non_tn_fp8_gemm_supported():
                            # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                            # have a valid transpose.
                            inputmat_total._create_transpose()
639
640
641
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
642
643
644
645
646
647
648
                if ctx.input_quantizer is not None and not isinstance(
                    inputmat_total, QuantizedTensor
                ):
                    # Async gather in BF16 does not asynchronously
                    # call quantizer after gather.
                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                    inputmat_total = ctx.input_quantizer(inputmat_total)
649

650
651
652
                # Make sure GEMM inputs have required data
                if isinstance(inputmat_total, QuantizedTensor):
                    inputmat_total.update_usage(columnwise_usage=True)
653
                if isinstance(grad_output, QuantizedTensor):
654
                    grad_output.update_usage(columnwise_usage=True)
655

656
657
658
659
660
661
662
663
664
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

                # Output buffer for overlapping grad input
                # reduce-scatter with wgrad GEMM
665
666
667
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                    rs_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
668
669
                    )

670
671
                # wgrad GEMM
                # Note: Fuse with bgrad computation if needed
672
                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
673
674
675
676
677
678
679
680
681
682
683
                wgrad, grad_bias_, _, rs_out = general_gemm(
                    inputmat_total,
                    grad_output,
                    get_workspace(),
                    layout="NT",
                    grad=True,
                    out_dtype=(
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
                    bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                    out=main_grad if ctx.fuse_wgrad_accumulation else None,
684
                    use_split_accumulator=use_split_accumulator,
685
                    accumulate=accumulate_wgrad_into_param_main_grad,
686
                    quantization_params=ctx.grad_weight_quantizer,
687
688
689
690
691
                    ub=ub_obj_wgrad,
                    ub_type=ub_type_wgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_wgrad,
                )
692
                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
693

694
695
696
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
                        dgrad = rs_out
697
                    else:
698
                        dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
699

700
701
702
                if grad_bias is None:
                    grad_bias = grad_bias_
                del grad_bias_
703

704
                # Deallocate input tensor
705
706
                if ctx.owns_input:
                    clear_tensor_data(inputmat_total)
707

708
            # Don't return grad bias if not needed
709
710
711
            if not ctx.use_bias:
                grad_bias = None

712
            # Make sure all tensor-parallel communication is finished
713
714
715
716
717
718
719
720
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
721
            # Handle custom DDP from mcore.
722
723
724
725
726
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
727
                weight.grad_added_to_main_grad = True
728
                if getattr(weight, "zero_out_wgrad", False):
729
730
731
732
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
733
                    )
734
                else:
735
736
737
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
738
                    )
739
740
741
742
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
743

744
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
745
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
746
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
747
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
748

749
        # Scatter fp8 weight buffers
750
        if ctx.fp8 and not isinstance(weight, QuantizedTensor):
751
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
752
        return (
753
            wgrad,
754
755
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
756
757
758
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
759
760
761
762
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
763
764
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
765
766
767
768
769
770
771
772
773
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
774
775
776
777
778
779
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
780
            None,  # ub_name
781
            None,  # fp8_output
782
            None,  # fsdp_group
783
784
            None,  # module
            None,  # skip_fp8_weight_update
785
            None,  # debug
786
787
788
789
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    name: str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        """Create the weight/bias parameters and configure TP and comm+GEMM overlap.

        See the class docstring for the meaning of each argument.

        Raises
        ------
        ValueError
            if ``parameters_split`` is empty or its sizes do not sum to ``out_features``.
        TypeError
            if ``parameters_split`` is neither a dict nor a sequence of strings.
        RuntimeError
            if a split size is not divisible by the column-parallel TP size, or if FP8
            parameters are requested together with a multi-parameter split.
        """
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With return_bias the bias is returned to the caller instead of being added here.
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.name = name

        if TEDebugState.debug_enabled:
            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

        # `device` may be a str or a torch.device. A torch.device never compares equal to a
        # str with `==` (torch.device("meta") == "meta" is False), so normalize first;
        # otherwise a torch.device("meta") argument would skip deferred initialization.
        is_meta_device = device is not None and torch.device(device).type == "meta"
        if is_meta_device:
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        # Tensor-parallel group and world size. Without a group, `tp_size` is used as given;
        # for tp_size > 1 the group has to be supplied later via set_tensor_parallel_group().
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Column parallel shards the output features, row parallel shards the input features.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            # ub_name is necessarily None when this fires, so don't interpolate it.
            assert ub_name is not None, (
                "Comm+GEMM overlap is enabled but `ub_name` was not provided; it is required "
                "to look up the userbuffer communicators for this layer."
            )
        self.ub_name = ub_name

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params; individual parameters are dim-0 slices of these.
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for split_name, split_size in parameters_split.items():
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(split_name, str) for split_name in parameters_split):
            # Split parameters evenly; a non-divisible out_features is caught below.
            split_size = out_features // len(parameters_split)
            for split_name in parameters_split:
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # Without a bias, the bias names are bound to empty (non-Parameter) tensors.
            # A dedicated loop variable avoids rebinding the `bias` argument.
            for bias_name in self.bias_names:
                empty_bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, bias_name, empty_bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=is_meta_device)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        self.gemm_bias_unfused_add = bool(self.parallel_mode == "row" and self.apply_bias)

1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Initialize FP8 scales/amaxes, then apply Linear-specific quantizer settings.

        ``fwd`` selects the forward (True) or backward (False) quantizers.
        """
        super().set_meta_tensor(fwd, recipe)

        # NOTE(review): the recipe used for customization is re-read from the global FP8
        # state rather than taken from the `recipe` argument -- confirm this is intended.
        active_recipe = FP8GlobalStateManager.get_fp8_recipe()
        if not active_recipe.float8_current_scaling():
            # Other recipes (mxfp8, etc.) need no Linear-specific customization yet.
            return
        self._customize_quantizers_float8_current_scaling(fwd, active_recipe)

1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
    def reset_parameters(self, defer_init=False):
        """Initialize parameters and tag them with tensor-parallel attributes.

        When ``defer_init`` is true, both the base initialization work governed by that flag
        and the tensor-parallel tagging below are skipped.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        # Row-parallel layers shard the weight along dim 1 (in_features); otherwise dim 0.
        weight_split_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=weight_split_dim,
                stride=1,
            )

        if not self.use_bias:
            return

        for bias_name in self.bias_names:
            if self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(getattr(self, bias_name), True, 0, 1)
            elif self.parallel_mode == "row":
                # Row-parallel biases are not sharded; only record sequence parallelism.
                getattr(self, bias_name).sequence_parallel = self.sequence_parallel

1115
    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : Optional[bool], default = False
                    whether to request an output quantizer for the GEMM output. Forced to
                    `True` when the fprop reduce-scatter userbuffer is an FP8 buffer.
        fp8_grad : Optional[bool], default = False
                  whether to request a quantizer for the input gradient. Forced to `True`
                  when the dgrad reduce-scatter userbuffer is an FP8 buffer.
        """
        debug = TEDebugState.debug_enabled
        if debug:
            self._validate_name()

        skip_fp8_weight_update = (
            FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
            if FP8GlobalStateManager.fp8_graph_capturing()
            else None
        )
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        # An FP8 userbuffer forces the corresponding GEMM result to be quantized.
        if self.ub_overlap_rs_fprop and get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
            fp8_output = True
        if self.ub_overlap_rs_dgrad and get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
            fp8_grad = True

        with self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            # Gather the (possibly split) weight and bias params into single tensors.
            weights = [getattr(self, param_name) for param_name in self.weight_names]
            has_quantized_weight = any(isinstance(w, QuantizedTensor) for w in weights)
            if has_quantized_weight and not self.fp8:
                # FP8 is not active for this module: use the weights in high precision.
                weights = [w.dequantize() for w in weights]
            elif has_quantized_weight and len(weights) != 1:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
            weight_tensor = noop_cat(weights)
            bias_tensor = (
                noop_cat([getattr(self, param_name) for param_name in self.bias_names])
                if self.use_bias
                else None
            )

            if debug:
                quantizers = self._get_debug_quantizers(fp8_output, fp8_grad)
                if not any_feature_enabled(quantizers):
                    # No debug feature is in use: fall back to the regular, faster path.
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
                    debug = False
                if isinstance(weight_tensor, QuantizedTensor):
                    raise RuntimeError("FP8 weights are not supported in debug mode.")
            else:
                quantizers = self._get_quantizers(fp8_output, fp8_grad)

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
                grad_weight_quantizer,
                grad_output_quantizer,
            ) = quantizers

            # The quantization recipe may have changed since the weight was quantized, so
            # make sure the weight carries the current quantizer.
            if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
                weight_tensor._quantizer = weight_quantizer

            # Row-parallel bias is added after the TP collectives, outside the GEMM.
            gemm_bias = (
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None
            )

            grad_enabled = torch.is_grad_enabled()
            if grad_enabled:
                linear_fn = _Linear.apply
                args = []
            else:
                # Call forward directly, bypassing autograd; None stands in for ctx.
                linear_fn = _Linear.forward
                args = [None]
            args.extend(
                (
                    weight_tensor,
                    inp,
                    gemm_bias,
                    is_first_microbatch,
                    self.fp8,
                    self.fp8_calibration,
                    input_quantizer,
                    weight_quantizer,
                    output_quantizer,
                    grad_input_quantizer,
                    grad_weight_quantizer,
                    grad_output_quantizer,
                    self.fuse_wgrad_accumulation,
                    is_cpu_offload_enabled(),
                    self.tp_group,
                    self.tp_size,
                    self.sequence_parallel,
                    self.tp_size > 1,
                    self.activation_dtype,
                    self.parallel_mode,
                    grad_enabled,
                    self.ub_overlap_rs_fprop,
                    self.ub_overlap_ag_dgrad,
                    self.ub_overlap_ag_fprop,
                    self.ub_overlap_rs_dgrad,
                    self.ub_bulk_dgrad,
                    self.ub_bulk_wgrad,
                    self.ub_name,
                    fp8_output,
                    self.fsdp_group,
                    self,
                    skip_fp8_weight_update,
                    debug,
                )
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out
1260
1261
1262

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Collect the quantizers passed to ``_Linear`` for this call.

        Returns a 6-tuple ``(input, weight, output, grad_input, grad_weight, grad_output)``.
        Every entry is ``None`` when FP8 is not enabled for this module. Otherwise:

        * input and weight quantizers are always taken from the forward FP8 state;
        * the output quantizer is set only when ``fp8_output`` is truthy;
        * backward quantizers are set only while autograd is enabled, and the grad-input
          quantizer additionally requires ``fp8_grad``;
        * the grad-weight quantizer is always ``None``.

        Both branches now return a tuple (the non-FP8 path previously returned a list).
        """
        if not self.fp8:
            return (None,) * 6

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = False
        weight_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        output_quantizer = None
        if fp8_output:
            output_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_OUTPUT]

        grad_input_quantizer = None
        grad_weight_quantizer = None
        grad_output_quantizer = None
        if torch.is_grad_enabled():
            bwd_quantizers = self.quantizers["scaling_bwd"]
            grad_output_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        """Wrap each regular quantizer in a ``DebugQuantizer`` labelled with its tensor role.

        The order matches ``_get_quantizers``: activation, weight, output, dgrad, wgrad,
        gradient.
        """
        base_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        assert TEDebugState.debug_enabled
        # Local import, presumably to avoid loading debug tooling unless debugging is on.
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        tensor_roles = ("activation", "weight", "output", "dgrad", "wgrad", "gradient")
        wrapped = []
        for role, base_quantizer in zip(tensor_roles, base_quantizers):
            wrapped.append(DebugQuantizer(self.name, role, base_quantizer, self.tp_group))
        return tuple(wrapped)
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Apply current-scaling recipe settings to this module's quantizers.

        Forward (``fwd=True``): sets power-of-2 scaling and amax epsilon on the input and
        weight quantizers, and enables amax reduction over the TP group for the input
        quantizer of a sequence-parallel column-parallel layer.

        Backward (``fwd=False``): sets the same options on the grad-output quantizer, and
        enables amax reduction over the TP group for a sequence-parallel row-parallel layer.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"

        if fwd:
            fwd_quantizers = self.quantizers["scaling_fwd"]
            input_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
            weight_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]

            # Input quantizer: amax epsilon and power-of-2 scale configuration.
            input_q.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
            input_q.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
            # Weight quantizer: same options, from the weight-specific recipe entry.
            weight_q.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
            weight_q.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon

            # Parallel related: reduce the input amax across the TP group.
            if self.sequence_parallel and self.parallel_mode == "column":
                input_q.with_amax_reduction = True
                input_q.amax_reduction_group = self.tp_group
        else:
            grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]

            # Grad-output quantizer: amax epsilon and power-of-2 scale configuration.
            grad_output_q.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
            grad_output_q.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon

            # Parallel related: reduce the grad-output amax across the TP group.
            if self.sequence_parallel and self.parallel_mode == "row":
                grad_output_q.with_amax_reduction = True
                grad_output_q.amax_reduction_group = self.tp_group