linear.py 67.8 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union, List
7
8
from functools import reduce
from operator import mul as multiply_op
9
import warnings
10
11
12

import torch

13
import transformer_engine_torch as tex
14

15
from transformer_engine.common.recipe import Recipe
16
from transformer_engine.pytorch import torch_version
17

18
from .base import (
19
20
    fill_userbuffers_buffer_for_all_gather,
    get_dummy_wgrad,
21
    get_ub,
22
    get_workspace,
23
24
25
26
27
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
28
from ._common import noop_cat, WeightGradStore
29
from ..fp8 import FP8GlobalStateManager
30
31
from ..utils import (
    cast_if_needed,
32
    clear_tensor_data,
33
    divide,
34
    init_method_constant,
35
36
    requires_grad,
    needs_quantized_gemm,
37
    assert_dim_for_fp8_exec,
38
39
    nvtx_range_pop,
    nvtx_range_push,
40
41
42
43
44
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
45
    symmetric_all_reduce,
46
47
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
48
    is_fp8_activation_recompute_enabled,
49
    in_fp8_activation_recompute_phase,
50
51
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
52
53
)
from ..cpp_extensions import (
54
    general_gemm,
55
)
56
from ..constants import GemmParallelModes, dist_group_type
57
from ..jit import no_torch_dynamo
58
from ..graph import is_graph_capturing
59
60
from ..tensor.quantized_tensor import (
    QuantizedTensor,
61
    QuantizedTensorBase,
62
63
64
65
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
66
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
67
from ..tensor.mxfp8_tensor import MXFP8Quantizer
68
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
69
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
70
from ..export import is_in_onnx_export_mode, assert_warmed_up
71
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
72
73
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
74

75
76
77
78
79
80
81
82
83
84
85
# Names exported by ``from <this module> import *``.
__all__: List[str] = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        grad_weight_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
        symmetric_ar_type: str,
        debug: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run the forward GEMM ``y = x * w^T`` and cache state for backward.

        Steps:
          1. Cast or quantize the input. When the layer is column-parallel with
             sequence parallelism, all-gather it along the first dim via NCCL,
             or via a Userbuffers all-gather when ``ub_overlap_ag_fprop`` is set.
          2. Cast or quantize the weight (quantized weights go through
             ``module.get_weight_workspace``).
          3. Run ``general_gemm``, optionally overlapped with a Userbuffers
             reduce-scatter (``ub_overlap_rs_fprop``).
          4. For row-parallel layers with ``tp_size > 1``, reduce-scatter
             (sequence parallel) or all-reduce (tensor parallel) the output.
          5. If ``is_grad_enabled``, save tensors and configuration on ``ctx``
             for ``backward``.

        Returns:
            The layer output tensor, after any tensor-parallel communication.

        Raises:
            ValueError: if ``fp8`` or ``debug`` is set and the input must be
                quantized but ``input_quantizer`` is None.
            AssertionError: if ``inp.shape[-1]`` does not match the weight's
                input-feature dimension.
        """
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        assert inp.shape[-1] == in_features, "GEMM not possible"

        # Configure tensor-parallel communication
        tp_world_size = get_distributed_world_size(tp_group)
        # The input is only needed in backward for the wgrad GEMM.
        backward_needs_input = is_grad_enabled and weight.requires_grad
        # Column-parallel + sequence-parallel needs a gathered input. Userbuffers
        # replaces the NCCL all-gather when ub_overlap_ag_fprop is set.
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )

        # Configure Userbuffers communication (comm+GEMM overlap)
        ub_obj = None
        ub_type = None
        if ub_overlap_rs_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
        elif ub_overlap_ag_fprop:
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG

        # ------------------------------------------------------
        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp  # Input tensor to save for backward (maybe sharded)
        inputmat_total = None  # Input tensor to pass to GEMM (gathered)
        # True when this function produced the quantized input itself (as opposed
        # to receiving an already-quantized tensor). Only then is its usage
        # adjusted before saving for backward.
        own_quantized_input = False
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
        if with_input_all_gather_nccl or ub_overlap_ag_fprop:  # All-gather input tensor

            # Cast local input tensor if needed
            if fp8 or debug:
                if input_quantizer is None:
                    raise ValueError("Missing quantizer for input tensor")
                if not isinstance(inputmat, QuantizedTensorBase):
                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                    if isinstance(
                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                    ):
                        # All-gather is not supported with FP8 column-wise data
                        input_quantizer.set_usage(columnwise=False)
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP

            # Initialize gathered input tensor
            # The forward GEMM only consumes row-wise data, so only that usage
            # is requested for the gathered tensor.
            quantizer = None
            if fp8 or debug:
                quantizer = input_quantizer
                quantizer.set_usage(rowwise=True, columnwise=False)
            if with_input_all_gather_nccl:  # Perform NCCL all-gather
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
                    tp_group,
                    quantizer=quantizer,
                )
            elif ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
                inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_obj,
                    inputmat,
                    quantizer,
                    tp_group,
                )

        else:  # Do not all-gather input tensor
            if fp8 or debug:
                if isinstance(inputmat, QuantizedTensorBase):
                    inputmat.update_usage(rowwise_usage=True)
                else:
                    if input_quantizer is None:
                        raise ValueError("Missing quantizer for input tensor")
                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                    inputmat = input_quantizer(inputmat)
                    own_quantized_input = True
            else:
                inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
            inputmat_total = inputmat
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
        # ------------------------------------------------------
        # Input tensor is ready for GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare weight tensor
        # ------------------------------------------------------
        weightmat = weight
        if fp8 or debug:
            # Configure quantizer
            if weight_quantizer is not None:
                # Column-wise weight data is used by the dgrad GEMM in backward.
                # It is also requested when FP8 activation recompute is enabled
                # and this is not the recompute phase.
                columnwise_usage = is_grad_enabled and inp.requires_grad
                if not columnwise_usage:
                    columnwise_usage = (
                        is_fp8_activation_recompute_enabled()
                        and not in_fp8_activation_recompute_phase()
                    )
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

            # Get quantized weight
            # The workspace is refreshed when no microbatch info is given or on
            # the first microbatch. It is cached under "weight" only when
            # is_first_microbatch is provided.
            update_workspace = is_first_microbatch is None or is_first_microbatch
            weightmat = module.get_weight_workspace(
                tensor=weight,
                quantizer=weight_quantizer,
                cache_name=(None if is_first_microbatch is None else "weight"),
                update_workspace=update_workspace,
                skip_update_flag=skip_fp8_weight_update,
                fsdp_group=fsdp_group,
                workspace_dtype=activation_dtype,
            )
            weightmat.update_usage(rowwise_usage=True)

        else:
            weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
        # ------------------------------------------------------
        # Weight tensor is ready for GEMM...
        # ------------------------------------------------------

        # Cast bias to expected dtype
        bias_dtype = activation_dtype
        if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
            # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Choose whether to use GEMM kernel with split accumulator
        use_split_accumulator = _2X_ACC_FPROP
        if fp8:
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if hasattr(recipe, "fp8_gemm_fprop"):
                use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Output buffer for Userbuffers reduce-scatter
        # Shape is the input shape with the first dim divided across the TP group
        # and the last dim replaced by out_features.
        reduce_scatter_out = None
        if ub_overlap_rs_fprop:
            out_shape = list(inp.shape)
            out_shape[0] //= tp_world_size
            out_shape[-1] = out_features
            reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

        # ------------------------------------------------------
        # Forward GEMM
        # Note: y = x * w^T
        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm")
        gemm_out, *_, reduce_scatter_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=activation_dtype,
            bias=bias,
            use_split_accumulator=use_split_accumulator,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=reduce_scatter_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")
        # ------------------------------------------------------
        # Finished forward GEMM...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Prepare output tensor
        # Note: Perform tensor-parallel communication
        # ------------------------------------------------------
        out = None
        if ub_overlap_rs_fprop:
            # Reduce-scatter was already overlapped with the GEMM.
            out = reduce_scatter_out
        elif parallel_mode == "row" and tp_size > 1:
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            out = gemm_out
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                if symmetric_ar_type is not None:
                    out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
                else:
                    out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
        else:
            out = gemm_out
        # ------------------------------------------------------
        # Output tensor is ready to return...
        # ------------------------------------------------------

        # ------------------------------------------------------
        # Cache state for backward pass
        # ------------------------------------------------------

        if is_grad_enabled:
            ctx.weight_quantizer = weight_quantizer
            saved_inputmat = None

            # Backward must re-gather the (sharded) input for the wgrad GEMM.
            ctx.backward_input_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            if backward_needs_input:
                if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
                    # For sequence parallel in vanilla FP8, row-wise data is
                    # needed to gather the input. For MXFP8 (and FP8 block
                    # scaling), column-wise-only data can be all-gathered, so
                    # the row-wise data can be dropped.
                    if (
                        isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
                        or not ctx.backward_input_needs_gather
                    ):
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                saved_inputmat = inputmat

            # Weight with column-wise usage is needed for dgrad GEMM.
            if inp.requires_grad:
                if isinstance(weightmat, QuantizedTensorBase):
                    weightmat.update_usage(columnwise_usage=True)

            if cpu_offloading and saved_inputmat is not None:
                mark_activation_offload(saved_inputmat)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            if cpu_offloading:
                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

                if ctx.grad_added_to_main_grad:
                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
                    # You need to preserve the weight object to have all the attributes user
                    # sets for the weights. Because of this, it is not recommended to offload
                    # weights if weights are externally touched outside this module
                    ctx.weight_object = weight

            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
            ctx.input_quantizer = input_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.grad_weight_quantizer = grad_weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                # This check is needed to ensure that main_grad is not created
                # during the forward pass when using MCore FSDP as it creates
                # the main_grad buffer lazily before backprop
                if hasattr(weight, "__fsdp_param__"):
                    # MCore FSDP creates main_grad lazily before backward
                    ctx.main_grad_func = weight.get_main_grad
                else:
                    ctx.main_grad_func = lambda: weight.main_grad

            ctx.debug = debug
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False

            ctx.owns_input = saved_inputmat is not inp
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # Record whether this is the first FP8 module. During the
                # activation-recompute phase, restore the global flag that
                # is_first_fp8_module() may have changed.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
            ctx.wgrad_store = wgrad_store

        # ------------------------------------------------------
        # Cached state for backward pass is ready...
        # ------------------------------------------------------

        return out

    @staticmethod
444
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
445
        # pylint: disable=missing-function-docstring
446

447
448
449
450
451
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

452
        with torch.cuda.nvtx.range("_Linear_backward"):
453
454
455
456
            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
457

458
459
460
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None
461
462
463

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
464
                ctx.main_grad_func()
465
466
467
468
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

469
470
471
472
473
            if ctx.cpu_offloading:
                if ctx.grad_added_to_main_grad:
                    weight = ctx.weight_object
                if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                    weight.main_grad = main_grad
474

475
476
477
            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
478
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
479
480
481
482
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
483
                weight_fp8,
484
            )
485
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
486

487
            # Configure Userbuffers communication (comm+GEMM overlap)
488
            ctx.ub_obj_gradout = None
489
            ub_obj_dgrad = None
490
            ub_obj_wgrad = None
491
492
            ub_type_dgrad = None
            ub_type_wgrad = None
493
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
494
            if ctx.ub_overlap_ag:
495
                # Overlap grad_output all-gather with dgrad compute
496
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
497
498
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG
499
500
501
            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
502
503
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
504
505
506
507
            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
508
509
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
510
511
512
                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
513
                    ub_type_wgrad = tex.CommOverlapType.RS
514
515
516
517
518
519
520
521

            # --------------------------------------------------
            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
            # --------------------------------------------------

            # Unmodified grad output tensor
            grad_output_arg = grad_output
522

523
524
525
526
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
            if ctx.grad_output_quantizer is not None:
527
528
529
530
531
532
533
                quantizer = ctx.grad_output_quantizer
                quantizer.set_usage(rowwise=True, columnwise=True)
                if ctx.ub_overlap_ag:
                    # Userbuffers only supports communication for one
                    # tensor usage at a time. Configure quantizer with
                    # usage for only dgrad GEMM.
                    quantizer.set_usage(columnwise=False)
534

535
            # Prepare grad output tensor
536
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
537
538
539
540
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
541
542
543
544
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
545
            )
546
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
547

548
549
550
551
552
553
554
555
556
557
            # --------------------------------------------------
            # Grad output tensor is ready for computing grad input...
            # --------------------------------------------------

            # --------------------------------------------------
            # Prepare input tensor
            # Note: Input tensor is needed for wgrad GEMM.
            # Tensor-parallel communication is overlapped with dgrad
            # GEMM.
            # --------------------------------------------------
558
            inputmat_total = None
559
            inputmat_total_work = None
560
            if ctx.backward_input_needs_gather:
561
                quantizer = None
562
                if ctx.fp8 or ctx.debug:
563
                    quantizer = ctx.input_quantizer
564
565
566
567
568
569
                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                        # If data is in FP8, we compute FP8 transposes manually
                        quantizer.set_usage(rowwise=True, columnwise=False)
                    else:
                        # wgrad GEMM requires input with column-wise usage
                        quantizer.set_usage(rowwise=False, columnwise=True)
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
                if ctx.ub_bulk_dgrad:
                    inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
                        ub_obj_dgrad,
                        inputmat,
                        quantizer,
                        ctx.tp_group,
                    )
                else:
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                    inputmat_total, inputmat_total_work = gather_along_first_dim(
                        inputmat,
                        ctx.tp_group,
                        async_op=True,
                        quantizer=quantizer,
                    )
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
586
587
            else:
                inputmat_total = inputmat
588
589
590
            # --------------------------------------------------
            # Input tensor is ready for computing grad weight...
            # --------------------------------------------------
591

592
            # --------------------------------------------------
593
            # Compute grad input tensor
594
595
            # --------------------------------------------------

596
597
            dgrad = None
            dgrad_work = None
598
            if ctx.requires_dgrad:
599

600
601
602
603
604
605
606
607
                # Make sure required data is available
                if isinstance(grad_output, QuantizedTensorBase):
                    grad_output.update_usage(rowwise_usage=True)
                if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
                    weight_fp8.update_usage(columnwise_usage=True)

                # Choose whether to use GEMM kernel with split accumulator
                use_split_accumulator = _2X_ACC_DGRAD
608
609
610
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
611
                        use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
612

613
614
615
616
617
618
619
620
621
622
                # Update grad input quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # Output buffers for Userbuffers reduce-scatter
                gemm_out = None
                reduce_scatter_out = None
                if ctx.ub_overlap_rs_dgrad:
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
623
                    )
624
625
                elif ctx.ub_bulk_wgrad:
                    gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
626

627
628
629
630
                # dgrad GEMM
                # Note: dx = dy * w
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
                gemm_out, *_, reduce_scatter_out = general_gemm(
631
632
633
634
635
636
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
637
                    out=gemm_out,
638
                    out_dtype=ctx.activation_dtype,
639
                    use_split_accumulator=use_split_accumulator,
640
641
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
642
                    extra_output=reduce_scatter_out,
643
644
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
645
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
646

647
648
                # Prepare grad input tensor
                # Note: Perform tensor-parallel communication
649
                if ctx.ub_overlap_rs_dgrad:
650
651
652
653
                    dgrad = reduce_scatter_out
                elif ctx.ub_bulk_wgrad:
                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
                elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
654
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
655
                    dgrad = gemm_out
656
657
658
659
660
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
661
                        )
662
                    else:
663
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
664
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
665
666
667
668
669
670
671
672
673
674
                else:
                    dgrad = gemm_out

            # --------------------------------------------------
            # Grad input tensor has been computed...
            # --------------------------------------------------

            # --------------------------------------------------
            # Compute grad weight
            # --------------------------------------------------
675

676
677
            wgrad = None
            if ctx.requires_wgrad:
678

679
680
681
                # Prepare input tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
682
683
684
                if inputmat_total_work is not None:
                    inputmat_total_work.wait()
                    inputmat_total_work = None
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
                if ctx.fp8 or ctx.debug:
                    if isinstance(inputmat_total, QuantizedTensorBase):
                        inputmat_total.update_usage(columnwise_usage=True)
                    else:
                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                        inputmat_total = ctx.input_quantizer(inputmat_total)

                # Prepare grad output tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
                    # UB does not support overlapping grad output
                    # all-gather with wgrad GEMM. Also, we can't
                    # convert row-scaled MXFP8 to column-scaled, so we
                    # can't reuse the grad output that was gathered
700
701
                    # for the dgrad GEMM. We work around by explicitly
                    # overlapping the NCCL operation with the dgrad GEMM.
702
                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
703
704
705
706
707
708
709
710
711
712
713
714
715
716
                    # Get the communication stream from the dgrad GEMM and set it as the current torch stream
                    dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
                    with torch.cuda.stream(dgrad_comm_stream):
                        # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
                        # This ensures that we don't start until all communication for the dgrad GEMM is complete
                        grad_output, grad_output_work = gather_along_first_dim(
                            grad_output_arg,
                            ctx.tp_group,
                            async_op=True,
                            quantizer=ctx.grad_output_quantizer,
                        )
                    # Synchronize with the main stream
                    grad_output_work.wait()

717
718
719
720
721
722
                if ctx.fp8 or ctx.debug:
                    if isinstance(grad_output, QuantizedTensorBase):
                        grad_output.update_usage(columnwise_usage=True)
                    else:
                        ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                        grad_output = ctx.grad_output_quantizer(grad_output)
723

724
725
726
727
728
729
730
                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
                if ctx.fp8:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

731
732
733
734
735
736
737
738
739
                # Figure out whether to output wgrad GEMM directly into main grad
                if ctx.is_first_microbatch is not None:
                    accumulate_wgrad_into_param_main_grad = (
                        ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                    )
                else:
                    accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

                # Output buffer for overlapping FP8 grad input
740
                # reduce-scatter with wgrad GEMM
741
                reduce_scatter_out = None
742
                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
743
744
                    reduce_scatter_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
745
746
                    )

747
748
749
750
                # Arguments to include in wgrad GEMM closure
                wgrad_gemm_kwargs = {
                    "workspace": get_workspace(),
                    "out_dtype": (
751
752
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
                    "quantization_params": ctx.grad_weight_quantizer,
                    "accumulate": accumulate_wgrad_into_param_main_grad,
                    "layout": "NT",
                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
                    "use_split_accumulator": use_split_accumulator,
                    "grad": True,
                    "ub": ub_obj_wgrad,
                    "ub_type": ub_type_wgrad,
                    "extra_output": reduce_scatter_out,
                    "bulk_overlap": ctx.ub_bulk_wgrad,
                }

                def wgrad_gemm(
                    x: torch.Tensor,
                    dy: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                    """Perform wgrad GEMM: dw = dy^T * x

                    May be fused with bgrad computation.

                    May be called outside of this function to enable
                    some advanced communication/compute overlapping.

                    """
                    nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                    nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
                    return dw, db

                # Choose whether to call wgrad GEMM now or delay
784
                if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
785
786
787
788
789
790
791
792
793
794
795
                    if (
                        wgrad_gemm_kwargs["ub"] is not None
                        or wgrad_gemm_kwargs["ub_type"] is not None
                        or wgrad_gemm_kwargs["extra_output"] is not None
                        or wgrad_gemm_kwargs["bulk_overlap"]
                    ):
                        raise NotImplementedError(
                            "Delayed weight grad computation is not supported "
                            "with Userbuffers (tensor-parallel communication overlapping)"
                        )
                    ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
796
797
                else:

798
799
800
801
                    # Call wgrad GEMM now
                    wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                    # Update grad bias if needed
802
803
804
805
                    if grad_bias is None:
                        grad_bias = grad_bias_
                    del grad_bias_

806
                    # Deallocate input tensor if permitted
807
808
                    if ctx.owns_input:
                        clear_tensor_data(inputmat_total)
809

810
                # Update grad input if overlapping reduce-scatter with wgrad GEMM
811
812
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
813
                        dgrad = reduce_scatter_out
814
                    else:
815
816
817
818
819
                        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

            # --------------------------------------------------
            # Grad weight has been computed...
            # --------------------------------------------------
820

821
            # Don't return grad bias if not needed
822
823
824
            if not ctx.use_bias:
                grad_bias = None

825
            # Make sure all tensor-parallel communication is finished
826
827
828
829
830
831
832
833
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
834
            # Handle custom DDP from mcore.
835
836
837
838
839
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
840
                weight.grad_added_to_main_grad = True
841
                if getattr(weight, "zero_out_wgrad", False):
842
843
844
845
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
                        zero=True,
846
                    )
847
                else:
848
849
850
                    wgrad = get_dummy_wgrad(
                        list(weight.main_grad.shape),
                        weight.dtype,
851
                    )
852
853
854
855
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
856

857
        # Update FP8 scaling factors if needed
858
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
859
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
860
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
861
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
862

863
        # Scatter fp8 weight buffers
864
        if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
865
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
866
        return (
867
            wgrad,
868
869
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
870
871
872
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
873
            None,  # wgrad_store
874
875
876
877
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_input_quantizer
878
879
            None,  # grad_weight_quantizer
            None,  # grad_output_quantizer
880
881
882
883
884
885
886
887
888
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
889
890
891
892
893
894
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
895
            None,  # ub_name
896
            None,  # fp8_output
897
            None,  # fsdp_group
898
899
            None,  # module
            None,  # skip_fp8_weight_update
900
            None,  # symmetric_ar_type
901
            None,  # debug
902
903
904
905
        )


class Linear(TransformerEngineBaseModule):
906
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the name passed to `get_rng_state_tracker` to select a specific RNG tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    name: str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  controls the dtype used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether or not to delay weight gradient computation. If set to `True`,
                         it's the user's responsibility to call `module.backward_dw` to compute
                         weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                   Type of symmetric memory all-reduce to use during the forward pass.
                   This can help in latency bound communication situations.
                   Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                   is used.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        """Allocate the weight/bias parameters and configure parallelism options.

        See the class docstring for a description of each argument.

        Raises
        ------
        ValueError
            If ``parameters_split`` is empty, or its split sizes do not sum to
            ``out_features``.
        TypeError
            If ``parameters_split`` is not ``None``, a dict, or an iterable of strings.
        RuntimeError
            If a split size cannot be evenly distributed across tensor-parallel
            ranks, or if FP8 parameters would have to be split into several tensors.
        AssertionError
            If ``parallel_mode`` is unsupported, Userbuffers overlap is requested
            without ``ub_name``, ``parameters_split`` is used on the meta device,
            or ``symmetric_ar_type`` is set with PyTorch older than 2.7.
        """
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With return_bias=True the bias is returned to the caller instead of being applied here
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.symmetric_ar_type = symmetric_ar_type
        self.name = name

        if TEDebugState.debug_enabled:
            # NOTE(review): this runs before the self.ub_* flags below are assigned from
            # the constructor arguments; confirm the helper still disables those options.
            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

        self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

        # `device` may be a str or a torch.device. A torch.device never compares equal
        # to a plain string, so normalize to the device type before checking for "meta".
        device_type = device.type if isinstance(device, torch.device) else device
        is_meta_device = device_type == "meta"
        if is_meta_device:
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        # Tensor-parallel group / world size
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each rank only holds its shard of the sharded dimension
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        # Bulk dgrad/wgrad overlap is disabled when dgrad reduce-scatter overlap is enabled
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            assert (
                ub_name is not None
            ), "Comm+GEMM overlap is enabled, but no ub_name was provided to identify the layer."
        self.ub_name = ub_name

        if self.symmetric_ar_type is not None:
            assert torch_version() >= (
                2,
                7,
                0,
            ), "Torch version must be at least 2.7 to use symmetric memory"

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for split_name, split_size in parameters_split.items():
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(split_name, str) for split_name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for split_name in parameters_split:
                self.weight_names.append(f"{split_name.rstrip('_')}_weight")
                self.bias_names.append(f"{split_name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid (also catches uneven "equal" splits)
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # No bias: store empty plain tensors (not parameters) under the bias names
            for bias_name in self.bias_names:
                empty_bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, bias_name, empty_bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=is_meta_device)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

1220
1221
1222
1223
1224
1225
1226
1227
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        After the base-class setup, the quantizers are customized for the
        current-scaling and block-scaling FP8 recipes. Other recipes (e.g.
        MXFP8) need no extra customization here.
        """
        super().set_meta_tensor(fwd, recipe)

        # customize quantizers based on each recipe & layer configs
        # NOTE(review): customization uses the globally active recipe, not the
        # `recipe` argument; confirm callers always pass the same recipe.
        active_recipe = FP8GlobalStateManager.get_fp8_recipe()
        if active_recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, active_recipe)
        elif active_recipe.float8_block_scaling():
            self._customize_quantizers_float8_blockwise_scaling(fwd, active_recipe)
        # elif for other recipes (mxfp8, etc.)

1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
    def reset_parameters(self, defer_init=False):
        """Reset parameters and, unless deferred, tag them with tensor-parallel metadata.

        Weights are marked as partitioned along dim 1 for row-parallel layers and
        along dim 0 otherwise. For row-parallel layers each bias gets a
        ``sequence_parallel`` attribute; for column-parallel layers each bias is
        marked as partitioned along dim 0.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            return

        # Weights: row-parallel shards the input dim, everything else the output dim
        weight_partition_dim = 1 if self.parallel_mode == "row" else 0
        for weight_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight_name),
                is_parallel=True,
                dim=weight_partition_dim,
                stride=1,
            )

        if not self.use_bias:
            return

        # Biases: only tagged for row/column parallel modes
        for bias_name in self.bias_names:
            bias_param = getattr(self, bias_name)
            if self.parallel_mode == "row":
                bias_param.sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(
                    tensor=bias_param,
                    is_parallel=True,
                    dim=0,
                    stride=1,
                )

1253
    @no_torch_dynamo()
1254
1255
1256
1257
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
1258
        fp8_output: Optional[bool] = False,
1259
        fp8_grad: Optional[bool] = False,
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        """
1282
1283
1284
        if is_in_onnx_export_mode():
            return self.onnx_forward(inp, fp8_output)

1285
1286
1287
1288
        debug = TEDebugState.debug_enabled
        if debug:
            self._validate_name()

1289
1290
1291
1292
        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
1293
1294
1295
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

1296
1297
1298
1299
1300
1301
1302
        if self.ub_overlap_rs_fprop:
            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
                fp8_output = True
        if self.ub_overlap_rs_dgrad:
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True

1303
1304
        with self.prepare_forward(
            inp,
1305
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
1306
        ) as inp:
1307

1308
            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
1309

1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
            quantizers = (
                self._get_quantizers(fp8_output, fp8_grad)
                if not debug
                else self._get_debug_quantizers(fp8_output, fp8_grad)
            )
            if debug:
                if not any_feature_enabled(quantizers):
                    # If no feature is used, then run faster implementation with debug = False.
                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
                    debug = False

                if isinstance(weight_tensor, QuantizedTensor):
                    raise RuntimeError("FP8 weights are not supported in debug mode.")

1324
1325
1326
1327
1328
            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
1329
1330
1331
                grad_weight_quantizer,
                grad_output_quantizer,
            ) = quantizers
1332

1333
1334
1335
1336
1337
1338
1339
1340
1341
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                inp,
1342
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
1343
1344
1345
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
1346
                self.wgrad_store,
1347
1348
1349
1350
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_input_quantizer,
1351
1352
                grad_weight_quantizer,
                grad_output_quantizer,
1353
                self.fuse_wgrad_accumulation,
1354
                is_cpu_offload_enabled(),
1355
1356
1357
1358
1359
1360
1361
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
1362
1363
1364
1365
1366
1367
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
1368
                self.ub_name,
1369
                fp8_output,
1370
                self.fsdp_group,
1371
1372
                self,
                skip_fp8_weight_update,
1373
                self.symmetric_ar_type,
1374
                debug,
1375
1376
1377
1378
1379
1380
1381
1382
            )
            out = linear_fn(*args)
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out
1383
1384
1385

    def _get_quantizers(self, fp8_output, fp8_grad):
        if not self.fp8:
1386
            return [None] * 6
1387
        grad_input_quantizer = None
1388
        grad_weight_quantizer = None
1389
1390
1391
        grad_output_quantizer = None
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
1392
        input_quantizer.internal = True
1393
        (weight_quantizer,) = self._get_weight_quantizers()
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        if torch.is_grad_enabled():
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_input_quantizer,
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
            grad_weight_quantizer,
            grad_output_quantizer,
        )

    def _get_debug_quantizers(self, fp8_output, fp8_grad):
        """Wrap each quantizer from ``_get_quantizers`` in a ``DebugQuantizer``.

        Must only be called while debug mode is enabled (asserted before any
        quantizer is fetched or mutated).

        Args:
            fp8_output: forwarded to ``_get_quantizers``.
            fp8_grad: forwarded to ``_get_quantizers``.

        Returns:
            A 6-tuple of ``DebugQuantizer`` objects, in the same order as
            ``_get_quantizers``, tagged with the role names
            activation/weight/output/dgrad/wgrad/gradient respectively.
        """
        assert TEDebugState.debug_enabled
        # Local import, presumably so the debug package is only loaded when debug
        # mode is actually in use.
        from ...debug.pytorch.debug_quantization import DebugQuantizer

        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
        # Role names align positionally with _get_quantizers' return order:
        # input, weight, output, grad_input, grad_weight, grad_output.
        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
        return tuple(
            DebugQuantizer(self.name, name, q, self.tp_group)
            for name, q in zip(names, original_quantizers)
        )
1420

1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
        """Get the weight tensors of the module.

        Returns the parameters named in ``self.weight_names``, in order. If any
        of them is a ``QuantizedTensor``:

        * with FP8 enabled, only a single (unsplit) weight parameter is
          supported;
        * with FP8 disabled, a warning is emitted and the quantized weights are
          dequantized; non-quantized weights are returned untouched.

        Raises:
            RuntimeError: if FP8 is enabled and a quantized weight is split
                across more than one parameter.
        """
        unfused_weights = [getattr(self, name) for name in self.weight_names]
        if not any(isinstance(w, QuantizedTensor) for w in unfused_weights):
            return unfused_weights

        if self.fp8:
            if len(unfused_weights) != 1:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
            return unfused_weights

        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        # Only dequantize the quantized entries: for a plain CPU/CUDA tensor,
        # ``torch.Tensor.dequantize`` returns a float32 copy, which would silently
        # change the dtype of any non-quantized weight in a mixed list.
        return [
            w.dequantize() if isinstance(w, QuantizedTensor) else w for w in unfused_weights
        ]

    def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return the module's weight and bias, each concatenated via ``noop_cat``.

        Handling of quantized weights (single-parameter restriction under FP8,
        dequantization when FP8 is off) is performed by ``_get_weight_tensors``,
        so the weights received here need no further checks.

        Returns:
            ``(weight_tensor, bias_tensor)``; ``bias_tensor`` is ``None`` when
            ``self.use_bias`` is false.
        """
        weight_tensor = noop_cat(self._get_weight_tensors())
        bias_tensor = (
            noop_cat([getattr(self, name) for name in self.bias_names]) if self.use_bias else None
        )
        return weight_tensor, bias_tensor

    def onnx_forward(
        self,
        inp: torch.Tensor,
        fp8_output: bool,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """
        ONNX-compatible version of the forward function that provides numerical equivalence
        while only using operations that have defined ONNX symbolic translations.
        This simplified implementation is designed specifically for inference scenarios.

        When FP8 is enabled, input and weight quantization is emulated with an
        ``onnx_quantize`` / ``onnx_dequantize`` round trip; the GEMM then runs with
        weight, input and bias cast to the input's original dtype.

        Args:
            inp: input activation.
            fp8_output: request a quantized output (ignored when FP8 is disabled).

        Returns:
            The GEMM output, or ``(output, bias)`` when ``self.return_bias`` is set.

        Raises:
            NotImplementedError: if a quantized output is requested while FP8 is on.
        """
        from ..export import onnx_gemm

        assert_warmed_up(self)
        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
        weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
        (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            *_,
        ) = self._get_quantizers(fp8_output, False)
        # Reject the unsupported configuration before doing any quantization or GEMM work.
        if output_quantizer is not None:
            raise NotImplementedError("ONNX export of quantized output is not supported")

        inp_dtype = inp.dtype
        if input_quantizer is not None:
            inp_q = input_quantizer.onnx_quantize(inp)
            # Dequantize back and restore the caller's dtype.
            inp = input_quantizer.onnx_dequantize(inp_q).to(inp_dtype)

        if weight_quantizer is not None:
            weight_q = weight_quantizer.onnx_quantize(weight_tensor)
            weight_tensor = weight_quantizer.onnx_dequantize(weight_q)
        if bias_tensor is not None:
            bias_tensor = bias_tensor.to(inp_dtype)
        weight_tensor = weight_tensor.to(inp_dtype)

        output = onnx_gemm(weight_tensor, inp, bias_tensor if self.apply_bias else None)

        if self.return_bias:
            return output, bias_tensor
        return output

1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Apply current-scaling recipe settings to this module's quantizers.

        Args:
            fwd: if true, configure the forward GEMM1 input and weight
                quantizers; otherwise configure the backward GRAD_OUTPUT1
                quantizer.
            recipe: must satisfy ``recipe.float8_current_scaling()``
                (asserted). Its per-tensor ``power_2_scale`` and
                ``amax_epsilon`` values are copied onto the quantizers as
                ``force_pow_2_scales`` and ``amax_epsilon``.

        With sequence parallelism, the GEMM1 input quantizer (column-parallel,
        forward) or the GRAD_OUTPUT1 quantizer (row-parallel, backward) is also
        set to reduce amax over ``self.tp_group``.
        """
        assert (
            recipe.float8_current_scaling()
        ), "current scaling recipe quantizer customization here"

        if not fwd:
            grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_cfg = recipe.fp8_quant_bwd_grad
            grad_output_q.force_pow_2_scales = grad_cfg.power_2_scale
            grad_output_q.amax_epsilon = grad_cfg.amax_epsilon
            # Row-parallel + sequence parallel: reduce grad-output amax over the TP group.
            if self.sequence_parallel and self.parallel_mode == "row":
                grad_output_q.with_amax_reduction = True
                grad_output_q.amax_reduction_group = self.tp_group
            return

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        weight_q = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]

        input_cfg = recipe.fp8_quant_fwd_inp
        input_q.force_pow_2_scales = input_cfg.power_2_scale
        input_q.amax_epsilon = input_cfg.amax_epsilon

        weight_cfg = recipe.fp8_quant_fwd_weight
        weight_q.force_pow_2_scales = weight_cfg.power_2_scale
        weight_q.amax_epsilon = weight_cfg.amax_epsilon

        # Column-parallel + sequence parallel: reduce input amax over the TP group.
        if self.sequence_parallel and self.parallel_mode == "column":
            input_q.with_amax_reduction = True
            input_q.amax_reduction_group = self.tp_group
1556
1557
1558
1559
1560
1561
1562
1563

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Return the module's weight quantizers as a one-element list.

        The single entry is ``None`` when FP8 is disabled; otherwise it is the
        GEMM1 weight quantizer from the forward scaling set, flagged
        ``internal`` before being returned.
        """
        if self.fp8:
            quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
            quantizer.internal = True
        else:
            quantizer = None
        return [quantizer]
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582

    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on blockwise scaling recipe + linear.

        With sequence parallelism, sets ``all_gather_usage = True`` on the GEMM1
        input quantizer (forward, column-parallel) or on the GRAD_OUTPUT1
        quantizer (backward, row-parallel).

        Args:
            fwd: if true, customize the forward ("scaling_fwd") quantizers;
                otherwise the backward ("scaling_bwd") ones.
            recipe: must satisfy ``recipe.float8_block_scaling()`` (asserted).
        """
        assert (
            recipe.float8_block_scaling()
        ), "blockwise scaling recipe quantizer customization here"

        if fwd:
            if self.sequence_parallel and self.parallel_mode == "column":
                # set compact for inp tensor X
                # NOTE(review): presumably because X is all-gathered across the TP
                # group in this configuration — confirm against _Linear.forward.
                self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM1_INPUT
                ].all_gather_usage = True
        else:
            if self.sequence_parallel and self.parallel_mode == "row":
                # set compact for grad_output tensor dY
                # NOTE(review): presumably because dY is all-gathered across the TP
                # group in this configuration — confirm against _Linear.backward.
                self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ].all_gather_usage = True