linear.py 54.8 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union
7
8
from functools import reduce
from operator import mul as multiply_op
9
10
11

import torch

12
import transformer_engine_torch as tex
13

14
from transformer_engine.common.recipe import Recipe
15
16
17
18
from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
19
    get_dummy_wgrad,
20
21
22
23
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
24
25
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ..fp8 import FP8GlobalStateManager
26
27
from ..utils import (
    cast_if_needed,
28
    clear_tensor_data,
29
    divide,
30
    init_method_constant,
31
    non_tn_fp8_gemm_supported,
32
    assert_dim_for_fp8_exec,
33
34
35
    nvtx_range_pop,
    nvtx_range_push,
    requires_grad,
36
37
38
39
40
41
42
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
43
    is_fp8_activation_recompute_enabled,
44
    in_fp8_activation_recompute_phase,
45
46
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
47
48
)
from ..cpp_extensions import (
49
    general_gemm,
50
)
51
from ..constants import GemmParallelModes, dist_group_type
52
from ..jit import no_torch_dynamo
53
from ..graph import is_graph_capturing
54
55
56
57
58
59
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
60
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
61
from ..tensor.mxfp8_tensor import MXFP8Quantizer
62
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
63
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
64
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
65

# Public API of this module: only the ``Linear`` module is exported.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: Optional[str],
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
    ) -> torch.Tensor:
        """Forward pass of a (possibly quantized, tensor-parallel) linear layer.

        The input is viewed as ``(-1, in_features)`` (the non-FP8 path instead
        casts ``inp`` directly). It is optionally quantized with
        ``input_quantizer`` and/or all-gathered along the first dimension for
        column-parallel sequence parallelism. It is then multiplied against
        the (possibly quantized) weight workspace with ``general_gemm``. For
        row-parallel layers the result is reduce-scattered or all-reduced
        across ``tp_group``.

        Args:
            ctx: autograd context; when ``is_grad_enabled`` is True it is
                populated with the tensors and flags consumed by ``backward``.
            weight: weight of shape ``(out_features, in_features)``.
            inp: input whose last dimension must equal ``in_features``.
            bias: optional bias; cast to ``activation_dtype`` (to bfloat16
                when ``fp8`` is set and ``activation_dtype`` is float32).
            is_first_microbatch: ``None`` passes no ``cache_name`` to
                ``module.get_weight_workspace``. Otherwise
                ``cache_name="weight"`` is used, and the workspace is only
                updated when this is ``True``.
            fp8: quantize the GEMM input (and obtain the weight through
                ``module.get_weight_workspace``).
            fp8_calibration: when ``fp8`` is False, pass the high-precision
                input/weight to the quantizers' ``calibrate``.
            input_quantizer, weight_quantizer, output_quantizer: quantizers for
                the GEMM operands/output; ``input_quantizer`` is required when
                ``fp8`` is set.
            grad_output_quantizer, grad_input_quantizer: stored on ``ctx`` for
                ``backward``.
            fuse_wgrad_accumulation: when set (and ``weight.requires_grad``),
                ``weight.main_grad`` is stored on ``ctx``.
            cpu_offloading: tag the weight, the weight workspace and the saved
                input with offloading attributes.
            tp_group, tp_size: tensor-parallel process group and its size.
            sequence_parallel, tensor_parallel, parallel_mode: select the input
                all-gather (``"column"`` + sequence parallel) and the output
                reduce-scatter/all-reduce (``"row"``).
            activation_dtype: output dtype of the GEMM and target dtype of the
                high-precision operands.
            is_grad_enabled: whether to save state for ``backward``.
            ub_overlap_rs_fprop, ub_overlap_ag_fprop: overlap the fprop GEMM
                with the output reduce-scatter / input all-gather through the
                Userbuffers object ``ub_name + "_fprop"``.
            ub_overlap_ag_dgrad, ub_overlap_rs_dgrad, ub_bulk_dgrad,
                ub_bulk_wgrad: Userbuffers options stored for ``backward``
                (``ub_bulk_dgrad`` also changes the input quantizer usage).
            ub_name: suffix for the NVTX label and Userbuffers lookup; may be
                None.
            fp8_output: unused.
            fsdp_group: FSDP group used to scatter tensors saved for backward.
            module: owning module; provides ``get_weight_workspace``.
            skip_fp8_weight_update: forwarded to ``get_weight_workspace`` as
                ``skip_update_flag``.

        Returns:
            The GEMM output viewed as ``(-1, *inp.shape[1:-1], out_features)``.

        Raises:
            NotImplementedError: if ``fp8`` is set together with a Userbuffers
                fprop overlap and the recipe is not per-tensor scaling.
            ValueError: if ``fp8`` is set and ``input_quantizer`` is None.
        """

        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"

        tp_world_size = get_distributed_world_size(tp_group)
        # The input is only needed in backward for the wgrad GEMM.
        backward_needs_input = is_grad_enabled and weight.requires_grad

        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp.view(-1, in_features)
        inputmat_total = None
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )
        # Tracks whether `inputmat` was quantized here (vs. given already quantized)
        own_quantized_input = False
        # TODO(kwyss): Support FP8 allgather for FP8 block quantization.
        force_hp_input_gather = (
            fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
        )  # Perform TP communication in high precision.
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
            ):
                raise NotImplementedError(
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                    " current scaling"
                )

            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if with_input_all_gather_nccl:
                if force_hp_input_gather:
                    # Gather the high-precision input; the quantizer is handed
                    # to the gather helper rather than applied beforehand.
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat, tp_group, quantizer=input_quantizer
                    )
                else:
                    # Quantize the local shard first, then gather quantized data.
                    if not isinstance(inputmat, QuantizedTensor):
                        columnwise_usage = backward_needs_input and isinstance(
                            input_quantizer, MXFP8Quantizer
                        )
                        # force_hp_input_gather should enforce this
                        assert not isinstance(input_quantizer, Float8BlockQuantizer)
                        input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
                        inputmat = input_quantizer(inputmat)
                        own_quantized_input = True
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                    inputmat_total, _ = gather_along_first_dim(
                        inputmat,
                        tp_group,
                        quantizer=input_quantizer,
                    )
            else:
                if (
                    FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
                    and ub_bulk_dgrad
                ):
                    # reduce duplicated transpose in `_fix_gathered_fp8_transpose`
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
                else:
                    input_quantizer.set_usage(
                        rowwise=True,
                        columnwise=backward_needs_input,
                    )
                if not isinstance(inputmat, QuantizedTensor):
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
                elif backward_needs_input:
                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
                inputmat_total = inputmat
        else:
            # NOTE(review): this path casts `inp`, not the 2-D view above, so the
            # high-precision GEMM input keeps the original rank; presumably
            # general_gemm accepts N-D inputs -- confirm this is intentional.
            inputmat = cast_if_needed(inp, activation_dtype)
            if with_input_all_gather_nccl:
                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
            else:
                inputmat_total = inputmat
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

        # Cast weight to expected dtype
        if not fp8:
            weightmat = cast_if_needed(weight, activation_dtype)
        else:
            # Configure quantizer
            if weight_quantizer is not None:
                # Column-wise weight data is needed for the dgrad GEMM, or when
                # FP8 activation recompute will re-run this forward later.
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

            # FP8 cast to workspace buffer
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
            )

        # Cast bias to expected dtype (FP8 GEMMs with fp32 activations use a bf16 bias)
        bias_dtype = activation_dtype
        if fp8 and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Optional Userbuffers comm+GEMM overlap for the fprop GEMM
        ub_obj = None
        ub_type = None
        rs_out = None
        out_dtype = activation_dtype
        if ub_overlap_rs_fprop:
            # Output is reduce-scattered into `rs_out`, whose leading dim is
            # the flattened token count divided by the TP world size.
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)

        elif ub_overlap_ag_fprop:
            # Local input chunk is copied into the Userbuffers buffer and the
            # GEMM reads the full buffer (all-gather overlapped with compute).
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG
            if fp8:
                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
            ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
            inputmat_total = ub_obj.get_buffer(input_quantizer)

        nvtx_range_push(f"{nvtx_label}.gemm")
        # The FP8 recipe may override the default split-accumulator setting
        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        out, *_, rs_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=out_dtype,
            bias=bias,
            use_split_accumulator=fprop_gemm_use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=rs_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")

        if is_grad_enabled:
            saved_inputmat = None

            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            if backward_needs_input:
                if own_quantized_input and isinstance(inputmat, QuantizedTensor):
                    # For sequence parallel in vanilla FP8, rowwise data is needed
                    # to gather the input. For MXFP8, columnwise-only data
                    # can be all-gathered.
                    if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                if force_hp_input_gather:
                    assert not isinstance(inputmat, QuantizedTensor)
                saved_inputmat = inputmat

            # Weight with column-wise usage is needed for dgrad GEMM.
            if inp.requires_grad:
                if isinstance(weightmat, QuantizedTensor):
                    weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading:
                set_offloading_param(weight, "weight_offloading", True)
                set_offloading_param(weightmat, "weight_offloading", True)
                if saved_inputmat is not None:
                    set_offloading_param(saved_inputmat, "activation_offloading", True)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            if cpu_offloading:
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.force_hp_input_gather = force_hp_input_gather
            ctx.input_quantizer = input_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                ctx.main_grad = weight.main_grad

            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp_shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False
            # Saved input differs from the caller's tensor (e.g. cast/quantized
            # copy), so backward may free it after use.
            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # is_first_fp8_module() presumably consumes the flag; restore it
                # during the activation-recompute phase -- TODO confirm.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

        # Row Parallel Linear
        if ub_overlap_rs_fprop:
            out = rs_out
        elif parallel_mode == "row":
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

        out = out.view(-1, *inp_shape[1:-1], out_features)
        return out

    @staticmethod
390
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
391
        # pylint: disable=missing-function-docstring
392

393
394
395
396
397
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

398
        with torch.cuda.nvtx.range("_Linear_backward"):
399
400
401
402
403
404
405
406
407
408
            if (
                ctx.fp8
                and any(
                    [
                        ctx.ub_overlap_ag,
                        ctx.ub_overlap_rs_dgrad,
                        ctx.ub_bulk_dgrad,
                        ctx.ub_bulk_wgrad,
                    ]
                )
409
                and (ctx.fp8_recipe is not None)
410
            ):
411
                if not ctx.fp8_recipe.float8_per_tensor_scaling():
412
                    raise NotImplementedError(
413
414
                        "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
                        " current scaling"
415
                    )
416
417
418
419
420

            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
421
422
423
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
424
425
426
427
428
429
430
431

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
                ctx.main_grad
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

432
433
434
435
436
            if ctx.cpu_offloading:
                if ctx.grad_added_to_main_grad:
                    weight = ctx.weight_object
                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                    weight.main_grad = main_grad
437

438
439
440
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
441
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
442
443
444
445
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
446
                weight_fp8,
447
            )
448
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
449

450
            ctx.ub_obj_gradout = None
451
            ub_obj_dgrad = None
452
            ub_obj_wgrad = None
453
454
            ub_type_dgrad = None
            ub_type_wgrad = None
455
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
456
457
            rs_out = None
            dgrad_bulk = None
458
            if ctx.ub_overlap_ag:
459
                # Overlap grad_output all-gather with dgrad compute
460
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
461
462
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
463
464
465
466

            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
467
468
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
469
470
471
472
473
474
475
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                )

            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
476
477
478
479
                    # NOTE: Copying into communication buffer will always prefer rowwise data,
                    #       and will copy columnwise data if rowwise does not exist. In that case,
                    #       the all-gather will apply to the leading dimension of the transpose,
                    #       which then needs to be interleaved correctly before WGRAD.
480
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
481
482
483
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
                    ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
484
485
486
487

                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
488
489
490
491
                    ub_type_wgrad = tex.CommOverlapType.RS
                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
                rowwise_usage = True
                columnwise_usage = True
                if ctx.ub_overlap_ag and isinstance(
                    ctx.grad_output_quantizer,
                    (Float8Quantizer, Float8CurrentScalingQuantizer),
                ):
                    # If data is in FP8 and communication is handled
                    # with Userbuffers, we compute FP8 transposes
                    # manually
                    columnwise_usage = False
                ctx.grad_output_quantizer.set_usage(
                    rowwise=rowwise_usage,
                    columnwise=columnwise_usage,
                )

511
512
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
513
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
514
515
516
517
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
518
519
520
521
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
522
            )
523
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
524

525
            # Launch tensor-parallel communication for input tensor
526
            inputmat_total = None
527
            inputmat_total_work = None
528
            if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
529
530
531
                quantizer = None
                if ctx.fp8:
                    quantizer = ctx.input_quantizer
532
533
534
535
536
537
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
538
                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
539
                gather_quantizer = None if ctx.force_hp_input_gather else quantizer
540
541
542
543
                inputmat_total, inputmat_total_work = gather_along_first_dim(
                    inputmat,
                    ctx.tp_group,
                    async_op=True,
544
                    quantizer=gather_quantizer,
545
                )
546
                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
547
548
549
            else:
                inputmat_total = inputmat

550
            # Check whether to output wgrad GEMM directly into main grad
551
552
553
554
555
556
557
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

558
559
560
            # Compute grad input tensor
            dgrad = None
            dgrad_work = None
561
            if ctx.requires_dgrad:
562
563
564
565
566
567

                # Update quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # dgrad GEMM
568
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
569
570
571
572
573
574
575
576
                dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
                        dgrad_gemm_use_split_accumulator = (
                            recipe.fp8_gemm_dgrad.use_split_accumulator
                        )

577
578
579
580
581
582
583
584
585
                dgrad, *_, rs_out = general_gemm(
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
                    out=dgrad_bulk,
                    out_dtype=ctx.activation_dtype,
586
                    use_split_accumulator=dgrad_gemm_use_split_accumulator,
587
588
589
590
591
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
592
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
593
594
595
596
597

                # Launch tensor-parallel communication
                if ctx.ub_overlap_rs_dgrad:
                    dgrad = rs_out
                elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
598
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
599
600
601
602
603
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
604
                        )
605
                    else:
606
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
607
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
608

609
610
611
            # Compute grad weight tensor
            wgrad = None
            if ctx.requires_wgrad:
612
613

                # Synchronize tensor-parallel communication for input tensor
614
615
616
617
618
619
620
621
622
623
624
625
626
                if ctx.ub_bulk_dgrad:
                    inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                    if ctx.fp8:
                        if inputmat._data is None:
                            # All-gather executed on columnwise data and result is in rowwise data,
                            # so we need to fix the interleaving before WGRAD.
                            inputmat_total = _fix_gathered_fp8_transpose(
                                inputmat_total, ctx.tp_size
                            )
                        elif not non_tn_fp8_gemm_supported():
                            # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                            # have a valid transpose.
                            inputmat_total._create_transpose()
627
628
629
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
630
631
632
633
634
635
636
                if ctx.input_quantizer is not None and not isinstance(
                    inputmat_total, QuantizedTensor
                ):
                    # Async gather in BF16 does not asynchronously
                    # call quantizer after gather.
                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                    inputmat_total = ctx.input_quantizer(inputmat_total)
637

638
639
640
                # Make sure GEMM inputs have required data
                if isinstance(inputmat_total, QuantizedTensor):
                    inputmat_total.update_usage(columnwise_usage=True)
641
                if isinstance(grad_output, QuantizedTensor):
642
                    grad_output.update_usage(columnwise_usage=True)
643

644
645
646
647
648
649
650
651
652
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

                # Output buffer for overlapping grad input
                # reduce-scatter with wgrad GEMM
653
654
655
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                    rs_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
656
657
                    )

658
659
                # wgrad GEMM
                # Note: Fuse with bgrad computation if needed
660
                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
661
662
663
664
665
666
667
668
669
670
671
                wgrad, grad_bias_, _, rs_out = general_gemm(
                    inputmat_total,
                    grad_output,
                    get_workspace(),
                    layout="NT",
                    grad=True,
                    out_dtype=(
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
                    bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                    out=main_grad if ctx.fuse_wgrad_accumulation else None,
672
                    use_split_accumulator=use_split_accumulator,
673
674
675
676
677
678
                    accumulate=accumulate_wgrad_into_param_main_grad,
                    ub=ub_obj_wgrad,
                    ub_type=ub_type_wgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_wgrad,
                )
679
                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
680

681
682
683
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
                        dgrad = rs_out
684
                    else:
685
                        dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
686

687
688
689
                if grad_bias is None:
                    grad_bias = grad_bias_
                del grad_bias_
690

691
                # Deallocate input tensor
692
693
                if ctx.owns_input:
                    clear_tensor_data(inputmat_total)
694

695
            # Don't return grad bias if not needed
696
697
698
            if not ctx.use_bias:
                grad_bias = None

699
            # Make sure all tensor-parallel communication is finished
700
701
702
703
704
705
706
707
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
708
            # Handle custom DDP from mcore.
709
710
711
712
713
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
714
                weight.grad_added_to_main_grad = True
715
                if getattr(weight, "zero_out_wgrad", False):
716
717
718
719
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
720
                    )
721
                else:
722
723
724
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
725
                    )
726
727
728
729
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
730

731
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
732
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
733
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
734
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
735

736
        # Scatter fp8 weight buffers
737
        if ctx.fp8 and not isinstance(weight, QuantizedTensor):
738
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
739
        return (
740
            wgrad,
741
742
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
743
744
745
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
746
747
748
749
750
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_output_quantizer
            None,  # grad_input_quantizer
751
752
753
754
755
756
757
758
759
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
760
761
762
763
764
765
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
766
            None,  # ub_name
767
            None,  # fp8_output
768
            None,  # fsdp_group
769
770
            None,  # module
            None,  # skip_fp8_weight_update
771
772
773
774
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when the
                       tensor-parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Comm+GEMM overlap parameters
    ----------------------------
    All of these options only take effect when sequence parallelism is active.

    ub_overlap_ag : bool, default = `False`
                   enable all-gather overlap: in the forward GEMM when
                   `parallel_mode='column'`, in the dgrad GEMM when `parallel_mode='row'`.
    ub_overlap_rs : bool, default = `False`
                   enable reduce-scatter overlap in the forward GEMM when
                   `parallel_mode='row'`.
    ub_overlap_rs_dgrad : bool, default = `False`
                         enable reduce-scatter overlap in the dgrad GEMM when
                         `parallel_mode='column'`.
    ub_bulk_dgrad : bool, default = `False`
                   enable bulk overlap for the dgrad GEMM when `parallel_mode='column'`.
                   Ignored when `ub_overlap_rs_dgrad` is active.
    ub_bulk_wgrad : bool, default = `False`
                   enable bulk overlap for the wgrad GEMM when `parallel_mode='column'`.
                   Ignored when `ub_overlap_rs_dgrad` is active.
    ub_name : str, default = `None`
             name of the Comm+GEMM overlap buffers used by this layer. Required whenever any
             of the overlap options above ends up enabled.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With return_bias the bias is handed back to the caller instead of being added here.
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each rank holds only its shard of the partitioned dimension.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        # Bulk overlaps are mutually exclusive with dgrad reduce-scatter overlap.
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        if any(
            (
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            )
        ):
            assert (
                ub_name is not None
            ), "ub_name must be provided when Comm+GEMM overlap is enabled."
        self.ub_name = ub_name

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid (sizes are in terms of the
        # global, pre-TP out_features at this point)
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # Keep the bias attributes present (as empty tensors) so that
            # attribute lookups by name still succeed without a bias.
            for name in self.bias_names:
                empty_bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, empty_bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        After the base class has set up the quantizers for ``recipe``, applies
        the Linear-specific quantizer customizations for that same recipe.
        """
        super().set_meta_tensor(fwd, recipe)

        # Customize quantizers based on the recipe and layer configuration.
        # Use the recipe that was handed to the base class instead of
        # re-reading the global recipe, so the customization can never target
        # a different recipe than the one the quantizers were set up for.
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
        # elif for other recipes (mxfp8, etc.)

    def reset_parameters(self, defer_init=False):
        """Reset parameters and (unless deferred) tag them with TP attributes.

        Weights are marked as tensor-parallel along dim 1 for row parallelism
        and dim 0 otherwise. Biases get a `sequence_parallel` attribute for row
        parallelism and tensor-parallel attributes along dim 0 for column
        parallelism.
        """
        super().reset_parameters(defer_init=defer_init)

        if defer_init:
            return

        # Set parallelism attributes for linear weights
        weight_split_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=weight_split_dim,
                stride=1,
            )

        # Set parallelism attributes for linear biases
        if self.use_bias:
            for bias_name in self.bias_names:
                bias_param = getattr(self, bias_name)
                if self.parallel_mode == "row":
                    bias_param.sequence_parallel = self.sequence_parallel
                elif self.parallel_mode == "column":
                    set_tensor_model_parallel_attributes(bias_param, True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : Optional[bool], default = `False`
                    when FP8 is enabled, request a quantized output: an output quantizer is
                    passed to the linear function. Forced to `True` when forward
                    reduce-scatter overlap uses an FP8 communication buffer.
        fp8_grad : Optional[bool], default = `False`
                  when FP8 is enabled and autograd is active, request a quantized input
                  gradient: a grad-input quantizer is passed to the linear function. Forced
                  to `True` when dgrad reduce-scatter overlap uses an FP8 communication
                  buffer.

        Returns
        -------
        The output tensor, or `(output, bias)` when `return_bias=True`.
        """
        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        # FP8 communication buffers require FP8 GEMM outputs / grads.
        if self.ub_overlap_rs_fprop:
            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True

        with self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            # Get concatenated weight and bias tensors
            unfused_weights = [getattr(self, name) for name in self.weight_names]
            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                if self.fp8:
                    if len(unfused_weights) != 1:
                        raise RuntimeError(
                            "Splitting QuantizedTensor into multiple params is not supported"
                        )
                else:
                    # Quantized weights outside FP8 execution run in high precision.
                    unfused_weights = [w.dequantize() for w in unfused_weights]
            weight_tensor = noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
            else:
                bias_tensor = None

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_output_quantizer,
                grad_input_quantizer,
            ) = self._get_quantizers(fp8_output, fp8_grad)

            # Make sure weight tensor has correct quantizer
            # Note: Quantizer might have changed if quantization
            # recipe changed
            if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
                weight_tensor._quantizer = weight_quantizer

            # Without autograd, call forward directly; it takes ctx as its first argument.
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                inp,
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_output_quantizer,
                grad_input_quantizer,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
            )
            out = linear_fn(*args)
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Return (input, weight, output, grad_output, grad_input) quantizers.

        All five are `None` when FP8 is disabled. The output quantizer is only
        set when `fp8_output` is requested. The backward quantizers are only
        set when autograd is enabled, and the grad-input one additionally
        requires `fp8_grad`.
        """
        if not self.fp8:
            return [None] * 5
        grad_input_quantizer = None
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = False
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        if torch.is_grad_enabled():
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_output_quantizer,
            grad_input_quantizer,
        )

    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + linear.

        Forward: copies `power_2_scale` / `amax_epsilon` from the recipe's input
        and weight settings onto the GEMM1 input/weight quantizers, and enables
        amax reduction over `tp_group` for the input quantizer in
        column-parallel sequence-parallel mode.
        Backward: does the same for the GRAD_OUTPUT1 quantizer, with amax
        reduction in row-parallel sequence-parallel mode.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"
        if fwd:
            fwd_quantizers = self.quantizers["scaling_fwd"]
            input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
            weight_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]
            # set configs about amax epsilon and power_2_scale
            input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
            input_quantizer.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
            # configure the weight quantizer from the recipe's weight settings
            weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
            weight_quantizer.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
            # parallel related
            if self.sequence_parallel and self.parallel_mode == "column":
                # customize input_quantizer with amax reduction TP group
                # NOTE(review): presumably so all TP ranks share one scale for the
                # sequence-sharded input; confirm against the quantizer docs.
                input_quantizer.with_amax_reduction = True
                input_quantizer.amax_reduction_group = self.tp_group
        else:
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            # set grad_output_quantizer with amax epsilon and power_2_scale
            grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
            grad_output_quantizer.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
            # parallel related
            if self.sequence_parallel and self.parallel_mode == "row":
                # customize grad_output_quantizer with amax reduction TP group
                grad_output_quantizer.with_amax_reduction = True
                grad_output_quantizer.amax_reduction_group = self.tp_group