"...gmock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "47be72a952e672e2635c62353d25e611e9a70dac"
linear.py 46 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Callable, Dict, Optional, Tuple, Union
7
8
from functools import reduce
from operator import mul as multiply_op
9
10
11

import torch

12
import transformer_engine_torch as tex
13
14
15
16
17
18
19
20
21

from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
22
23
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ..fp8 import FP8GlobalStateManager
24
25
from ..utils import (
    cast_if_needed,
26
    clear_tensor_data,
27
    divide,
28
    init_method_constant,
29
    non_tn_fp8_gemm_supported,
30
    assert_dim_for_fp8_exec,
31
32
33
    nvtx_range_pop,
    nvtx_range_push,
    requires_grad,
34
35
36
37
38
39
40
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
41
    is_fp8_activation_recompute_enabled,
42
    in_fp8_activation_recompute_phase,
43
44
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
45
46
)
from ..cpp_extensions import (
47
    general_gemm,
48
)
49
from ..constants import GemmParallelModes, dist_group_type
50
from ..jit import no_torch_dynamo
51
from ..graph import is_graph_capturing
52
53
54
55
56
57
58
59
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)

from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
60

61
62
63
64
65
66
67
68
69
70
71
# Public API of this module: only the ``Linear`` module is exported; the
# autograd function ``_Linear`` is an internal implementation detail.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Autograd function implementing the forward and backward passes of ``Linear``.

    Calls custom CUDA extensions. The forward pass casts or quantizes the input
    and weight, optionally performs tensor-parallel communication (NCCL
    all-gather / reduce-scatter / all-reduce, or ``ub_*`` comm+GEMM overlap),
    and runs one GEMM through ``general_gemm``. The backward pass computes the
    input gradient (dgrad), the weight gradient (wgrad) and the bias gradient,
    together with the matching tensor-parallel communication.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        input_quantizer: Optional[Quantizer],
        weight_quantizer: Optional[Quantizer],
        output_quantizer: Optional[Quantizer],
        grad_output_quantizer: Optional[Quantizer],
        grad_input_quantizer: Optional[Quantizer],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,  # pylint: disable=unused-argument
        fsdp_group: Union[dist_group_type, None],
        module: torch.nn.Module,
        skip_fp8_weight_update: bool,
    ) -> torch.Tensor:
        """Compute ``out = inp @ weight^T (+ bias)`` with optional TP communication.

        Args:
            weight: weight matrix; unpacked as ``(out_features, in_features)``.
            inp: input tensor whose last dimension must equal ``in_features``.
            bias: optional bias; cast to ``activation_dtype`` (bf16 when running
                in FP8 with fp32 activations).
            is_first_microbatch: ``None`` disables the cached quantized weight;
                ``True`` refreshes the cache; ``False`` reuses it. It also controls
                whether wgrad accumulates into ``main_grad`` in backward.
            fp8: run the GEMM with quantized (FP8) operands.
            fp8_calibration: when not in FP8 mode, calibrate the input/weight
                quantizers on the current tensors.
            input_quantizer, weight_quantizer, output_quantizer,
            grad_output_quantizer, grad_input_quantizer: quantizers for the
                respective tensors (may be ``None`` when unused).
            fuse_wgrad_accumulation: write wgrad directly into
                ``weight.main_grad``.
            cpu_offloading: tag weight/activation tensors for CPU offloading.
            tp_group, tp_size, sequence_parallel, tensor_parallel: tensor-parallel
                configuration.
            activation_dtype: dtype used for non-quantized compute and outputs.
            parallel_mode: ``"column"``, ``"row"`` or ``None``.
            is_grad_enabled: whether to save state for backward.
            ub_overlap_rs_fprop, ub_overlap_ag_dgrad, ub_overlap_ag_fprop,
            ub_overlap_rs_dgrad, ub_bulk_dgrad, ub_bulk_wgrad: comm+GEMM overlap
                switches; the matching objects are looked up with ``get_ub``.
            ub_name: prefix for ``get_ub`` keys and NVTX labels.
            fp8_output: unused.
            fsdp_group: group over which saved tensors are scattered/gathered.
            module: owning module; provides ``get_weight_workspace``.
            skip_fp8_weight_update: forwarded as ``skip_update_flag`` to
                ``get_weight_workspace``.

        Returns:
            Output reshaped to ``(-1, *inp.shape[1:-1], out_features)``.

        Raises:
            NotImplementedError: comm+GEMM overlap requested in FP8 with a
                non-delayed-scaling recipe.
            ValueError: FP8 requested without an input quantizer.
        """
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.forward"
        if ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ub_name}"

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"

        tp_world_size = get_distributed_world_size(tp_group)
        # The input is only needed in backward for the wgrad GEMM.
        backward_needs_input = is_grad_enabled and weight.requires_grad

        # Prepare input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
        nvtx_range_push(f"{nvtx_label}.input_cast_comm")
        inputmat = inp.view(-1, in_features)
        inputmat_total = None
        # Column-parallel + sequence-parallel needs the full input; gather it with
        # NCCL unless the all-gather is overlapped with the GEMM via userbuffers.
        with_input_all_gather_nccl = (
            parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
        )
        # NOTE(review): this flag is never set to True in this function, so the
        # "drop rowwise data of a self-quantized input before saving" branch below
        # is currently inactive. Enabling it must also ensure backward will not
        # need to all-gather the saved input's rowwise data -- confirm first.
        own_quantized_input = False
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
            if (
                any([ub_overlap_ag_fprop, ub_overlap_rs_fprop])
                and not FP8GlobalStateManager.get_fp8_recipe().delayed()
            ):
                raise NotImplementedError(
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling"
                )

            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if with_input_all_gather_nccl:
                assert not isinstance(
                    inputmat, QuantizedTensor
                ), "All gather of fp8 input is not supported"
                # Quantize during the gather; only rowwise data is used by fprop.
                # The unquantized local chunk is what gets saved for backward.
                input_quantizer.set_usage(rowwise=True, columnwise=False)
                inputmat_total, _ = gather_along_first_dim(
                    inputmat,
                    tp_group,
                    quantizer=input_quantizer,
                )
            else:
                # Also produce columnwise data when the input is kept for wgrad.
                input_quantizer.set_usage(
                    rowwise=True,
                    columnwise=backward_needs_input,
                )
                if not isinstance(inputmat, QuantizedTensor):
                    inputmat = input_quantizer(inputmat)
                elif backward_needs_input:
                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
                inputmat_total = inputmat
        else:
            inputmat = cast_if_needed(inp, activation_dtype)
            if with_input_all_gather_nccl:
                inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
            else:
                inputmat_total = inputmat
        nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

        # Cast weight to expected dtype
        weightmat = weight
        if not fp8:
            weightmat = cast_if_needed(weightmat, activation_dtype)
        else:
            if not isinstance(weight, QuantizedTensor):
                # Configure quantizer
                if weight_quantizer is not None:
                    # Request columnwise data when backward may need it: either the
                    # input requires grad, or FP8 activation recompute is enabled
                    # and this is not yet the recompute pass.
                    columnwise_usage = is_grad_enabled and inp.requires_grad
                    if not columnwise_usage:
                        columnwise_usage = (
                            is_fp8_activation_recompute_enabled()
                            and not in_fp8_activation_recompute_phase()
                        )
                    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

                # FP8 cast to workspace buffer
                # The quantized weight is cached under "weight" only when micro-
                # batching is in use, and refreshed on the first microbatch.
                update_workspace = is_first_microbatch is None or is_first_microbatch
                weightmat = module.get_weight_workspace(
                    tensor=weight,
                    quantizer=weight_quantizer,
                    cache_name=(None if is_first_microbatch is None else "weight"),
                    update_workspace=update_workspace,
                    skip_update_flag=skip_fp8_weight_update,
                    fsdp_group=fsdp_group,
                )

        # Cast bias to expected dtype
        # FP8 GEMMs take a bf16 bias when activations are fp32.
        bias_dtype = activation_dtype
        if fp8 and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16
        bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

        # Configure output quantizer
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Calibrate quantizers if needed
        if not fp8 and fp8_calibration:
            if input_quantizer is not None:
                input_quantizer.calibrate(inputmat_total)
            if weight_quantizer is not None:
                weight_quantizer.calibrate(weight)

        # Configure comm+GEMM overlap (userbuffers) for the forward GEMM
        ub_obj = None
        ub_type = None
        rs_out = None
        out_dtype = activation_dtype
        if ub_overlap_rs_fprop:
            # Reduce-scatter the GEMM output: each rank keeps 1/tp of the rows.
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.RS
            out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)

        elif ub_overlap_ag_fprop:
            # All-gather the input inside the GEMM: stage the local chunk into the
            # communication buffer and run the GEMM on the full buffer.
            ub_obj = get_ub(ub_name + "_fprop")
            ub_type = tex.CommOverlapType.AG
            if fp8:
                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
            ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
            inputmat_total = ub_obj.get_buffer(input_quantizer)

        nvtx_range_push(f"{nvtx_label}.gemm")
        out, *_, rs_out = general_gemm(
            weightmat,
            inputmat_total,
            get_workspace(),
            quantization_params=output_quantizer,
            out_dtype=out_dtype,
            bias=bias,
            use_split_accumulator=_2X_ACC_FPROP,
            ub=ub_obj,
            ub_type=ub_type,
            extra_output=rs_out,
        )
        nvtx_range_pop(f"{nvtx_label}.gemm")

        if is_grad_enabled:
            # Input kept for the wgrad GEMM (None when wgrad is not required)
            saved_inputmat = None
            if backward_needs_input:
                if own_quantized_input and isinstance(inputmat, QuantizedTensor):
                    inputmat.update_usage(rowwise_usage=False)
                saved_inputmat = inputmat

            if cpu_offloading:
                set_offloading_param(weight, "weight_offloading", True)
                set_offloading_param(weightmat, "weight_offloading", True)
                if saved_inputmat is not None:
                    set_offloading_param(saved_inputmat, "activation_offloading", True)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,
                weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

            # TODO(ksivamani): Check memory usage
            tensors_to_save, tensor_objects = prepare_for_saving(
                saved_inputmat,
                weightmat,
                weight,
                bias,
            )
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.input_quantizer = input_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            if fuse_wgrad_accumulation and weight.requires_grad:
                # Kept outside save_for_backward: main_grad may be modified in place.
                ctx.main_grad = weight.main_grad

            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = bias is not None
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp_shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight.requires_grad
            ctx.reduce_and_update_bwd_fp8_tensors = False
            # If the saved input is a new tensor (cast/quantized copy), backward may
            # free it after use.
            ctx.owns_input = saved_inputmat is not inp
            ctx.is_input_fp8 = not own_quantized_input
            if ctx.fp8 and requires_grad(inp, weight, bias):
                # Only the first FP8 module's backward reduces/updates FP8 state.
                # During activation recompute, restore the global flag so that
                # querying it here has no side effect.
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

        # Row Parallel Linear
        if ub_overlap_rs_fprop:
            out = rs_out
        elif parallel_mode == "row":
            nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
            if sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                out, _ = allreduce(out, tp_group)
            nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

        out = out.view(-1, *inp_shape[1:-1], out_features)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients w.r.t. ``weight``, ``inp`` and ``bias``.

        Args:
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            A tuple with one entry per ``forward`` argument (after ``ctx``):
            ``(wgrad, dgrad, grad_bias)`` followed by ``None`` for every
            non-differentiable argument. ``dgrad`` is reshaped to the original
            input shape and is ``None`` when the input does not require grad.

        Raises:
            NotImplementedError: comm+GEMM overlap requested in FP8 with a
                non-delayed-scaling recipe.
        """
        # NVTX label for profiling
        nvtx_label = "transformer_engine._Linear.backward"
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

        with torch.cuda.nvtx.range("_Linear_backward"):
            if (
                ctx.fp8
                and any(
                    [
                        ctx.ub_overlap_ag,
                        ctx.ub_overlap_rs_dgrad,
                        ctx.ub_bulk_dgrad,
                        ctx.ub_bulk_wgrad,
                    ]
                )
                and not FP8GlobalStateManager.get_fp8_recipe().delayed()
            ):
                raise NotImplementedError(
                    "Comm+GEMM overlap is only supported with FP8 delayed scaling"
                )

            saved_tensors = ctx.saved_tensors
            inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
                restore_from_saved(ctx.tensor_objects, saved_tensors)
            )
            # Delete the references to tensor objects once they've been consumed
            # by the `restore_from_saved` method to construct back the actual tensors.
            ctx.tensor_objects = None

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            main_grad = (
                ctx.main_grad
                if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
                else None
            )

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                # Re-wrap the restored weight so main_grad can be attached to it again.
                weight = torch.nn.Parameter(weight, weight.requires_grad)
                weight.main_grad = main_grad

            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
            nvtx_range_push(f"{nvtx_label}.fsdp_gather")
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
                weight_fp8,
            )
            nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

            # Configure comm+GEMM overlap (userbuffers) for the dgrad/wgrad GEMMs
            ctx.ub_obj_gradout = None
            ub_obj_dgrad = None
            ub_obj_wgrad = None
            ub_type_dgrad = None
            ub_type_wgrad = None
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
            rs_out = None
            dgrad_bulk = None
            if ctx.ub_overlap_ag:
                # Overlap grad_output all-gather with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.AG

            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                ub_obj_dgrad = ctx.ub_obj_gradout
                ub_type_dgrad = tex.CommOverlapType.RS
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                )

            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
                    # NOTE: Copying into communication buffer will always prefer rowwise data,
                    #       and will copy columnwise data if rowwise does not exist. In that case,
                    #       the all-gather will apply to the leading dimension of the transpose,
                    #       which then needs to be interleaved correctly before WGRAD.
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                    ub_obj_dgrad = ctx.ub_obj_gradout
                    ub_type_dgrad = tex.CommOverlapType.AG
                    ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)

                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                    ub_type_wgrad = tex.CommOverlapType.RS
                    ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                    dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)

            # Prepare grad output tensor
            # Note: Cast to expected dtype and perform tensor-parallel communication
            if ctx.grad_output_quantizer is not None:
                # grad_output feeds both dgrad and wgrad, so both layouts are needed.
                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
            nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
            (
                grad_output,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx,
                grad_output,
                ctx.parallel_mode == "row",
                ctx.grad_output_quantizer,
            )
            nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")

            # Prepare input tensor
            # Note: Perform tensor-parallel communication if needed
            inputmat_total = None
            inputmat_total_work = None
            if (
                ctx.requires_wgrad
                and ctx.parallel_mode == "column"
                and ctx.sequence_parallel
                and not ctx.ub_bulk_dgrad
            ):
                # Launch the input all-gather asynchronously; it is waited on just
                # before the wgrad GEMM so it overlaps with the dgrad GEMM.
                quantizer = None
                if ctx.fp8:
                    quantizer = ctx.input_quantizer
                    quantizer.set_usage(rowwise=True, columnwise=True)
                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                inputmat_total, inputmat_total_work = gather_along_first_dim(
                    inputmat,
                    ctx.tp_group,
                    async_op=True,
                    quantizer=quantizer,
                )
                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
            else:
                inputmat_total = inputmat

            # Check whether to output wgrad GEMM directly into main grad
            # With micro-batching, the first microbatch overwrites main_grad and
            # later microbatches accumulate into it.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            # Compute grad input tensor
            dgrad = None
            dgrad_work = None
            if ctx.requires_dgrad:

                # Update quantizer
                if ctx.grad_input_quantizer is not None:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # dgrad GEMM
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
                dgrad, *_, rs_out = general_gemm(
                    weight_fp8,
                    grad_output,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
                    out=dgrad_bulk,
                    out_dtype=ctx.activation_dtype,
                    use_split_accumulator=_2X_ACC_DGRAD,
                    ub=ub_obj_dgrad,
                    ub_type=ub_type_dgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_dgrad,
                )
                nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

                # Launch tensor-parallel communication
                if ctx.ub_overlap_rs_dgrad:
                    dgrad = rs_out
                elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
                    nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
                    if ctx.sequence_parallel:
                        dgrad, dgrad_work = reduce_scatter_along_first_dim(
                            dgrad,
                            ctx.tp_group,
                            async_op=True,
                        )
                    else:
                        dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
                    nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")

            # Compute grad weight tensor
            wgrad = None
            if ctx.requires_wgrad:
                if ctx.ub_bulk_dgrad:
                    inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                    if ctx.fp8:
                        if inputmat._data is None:
                            # All-gather executed on columnwise data and result is in rowwise data,
                            # so we need to fix the interleaving before WGRAD.
                            inputmat_total = _fix_gathered_fp8_transpose(
                                inputmat_total, ctx.tp_size
                            )
                        elif not non_tn_fp8_gemm_supported():
                            # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                            # have a valid transpose.
                            inputmat_total._create_transpose()

                else:
                    if inputmat_total_work is not None:
                        # Synchronize tensor-parallel communication
                        inputmat_total_work.wait()
                        inputmat_total_work = None

                if isinstance(grad_output, QuantizedTensor):
                    # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
                    # already exists.
                    grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)

                if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                    rs_out = torch.empty(
                        dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                    )

                # wgrad GEMM
                # Note: Fuse with bgrad computation if needed
                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                wgrad, grad_bias_, _, rs_out = general_gemm(
                    inputmat_total,
                    grad_output,
                    get_workspace(),
                    layout="NT",
                    grad=True,
                    out_dtype=(
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
                    bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                    out=main_grad if ctx.fuse_wgrad_accumulation else None,
                    use_split_accumulator=_2X_ACC_WGRAD,
                    accumulate=accumulate_wgrad_into_param_main_grad,
                    ub=ub_obj_wgrad,
                    ub_type=ub_type_wgrad,
                    extra_output=rs_out,
                    bulk_overlap=ctx.ub_bulk_wgrad,
                )
                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

                # With bulk wgrad overlap, the reduce-scattered dgrad comes out of
                # the wgrad GEMM's communication.
                if ctx.ub_bulk_wgrad:
                    if ub_obj_wgrad.is_fp8_ubuf():
                        dgrad = rs_out
                    else:
                        dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)

                # Prefer the bias gradient from grad_output preprocessing, if any.
                if grad_bias is None:
                    grad_bias = grad_bias_
                del grad_bias_

                # Deallocate input tensor
                if ctx.owns_input:
                    clear_tensor_data(inputmat_total)

            # Don't return grad bias if not needed
            if not ctx.use_bias:
                grad_bias = None

            # Synchronize tensor parallel communication
            if inputmat_total_work is not None:
                inputmat_total_work.wait()
                inputmat_total_work = None
            if dgrad_work is not None:
                dgrad_work.wait()
                dgrad_work = None

        if ctx.requires_wgrad:
            # Handle custom DDP from mcore.
            if (
                ctx.fuse_wgrad_accumulation
                and weight is not None
                and hasattr(weight, "grad_added_to_main_grad")
            ):
                # The real gradient already lives in main_grad; return a placeholder
                # tensor of the right shape so autograd has something to accept.
                weight.grad_added_to_main_grad = True
                if getattr(weight, "zero_out_wgrad", False):
                    wgrad = torch.zeros(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    wgrad = torch.empty(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

        # Scatter fp8 weight buffers
        if ctx.fp8 and not isinstance(weight, QuantizedTensor):
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
        return (
            wgrad,
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # input_quantizer
            None,  # weight_quantizer
            None,  # output_quantizer
            None,  # grad_output_quantizer
            None,  # grad_input_quantizer
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
            None,  # ub_name
            None,  # fp8_output
            None,  # fsdp_group
            None,  # module
            None,  # skip_fp8_weight_update
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the name passed to `get_rng_state_tracker` to select a specific RNG tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass. When set to `"meta"`, parameter initialization is deferred and
          `parameters_split` must be `None`.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  controls the dtype used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        """Build the layer's parameters and its parallel / comm-overlap configuration.

        See the class docstring for the meaning of the public arguments. The
        ``ub_*`` overlap flags are only honoured when sequence parallelism is
        active (``sequence_parallel`` with a TP size > 1) and ``parallel_mode``
        is the one the flag applies to; ``ub_name`` is then required.

        Raises
        ------
        ValueError
            if ``parameters_split`` is empty or its sizes do not sum to
            ``out_features``.
        TypeError
            if ``parameters_split`` is neither a dict nor a sequence of strings.
        RuntimeError
            if a split size is not divisible by the TP size in column mode, or
            if FP8 parameters would have to be split into several params.
        AssertionError
            for an unsupported ``parallel_mode``, a ``parameters_split`` on the
            meta device, or comm+GEMM overlap requested without ``ub_name``.
        """
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

        # `device` may be a string or a torch.device. A torch.device never compares
        # equal to a plain string, so `device == "meta"` alone would miss
        # torch.device("meta"); check the device type for torch.device inputs.
        if isinstance(device, torch.device):
            on_meta_device = device.type == "meta"
        else:
            on_meta_device = device == "meta"

        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # When the bias is returned to the caller it is not added here.
        self.apply_bias = bias and not return_bias
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        if on_meta_device:
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        # Tensor-parallel group and world size
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Shard the partitioned dimension across the tensor-parallel group
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag
        )
        self.ub_overlap_rs_dgrad = (
            self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad
        )
        self.ub_bulk_dgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_dgrad
            and not self.ub_overlap_rs_dgrad
        )
        self.ub_bulk_wgrad = (
            self.parallel_mode == "column"
            and self.sequence_parallel
            and ub_bulk_wgrad
            and not self.ub_overlap_rs_dgrad
        )

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs
        )
        self.ub_overlap_ag_dgrad = (
            self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag
        )

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            # The previous message interpolated `ub_name`, which is always None
            # here, so it always read "layer 'None' is not initialized".
            assert ub_name is not None, (
                "Comm+GEMM overlap is enabled, but no `ub_name` was provided to "
                "identify the initialized overlap layer."
            )
        self.ub_name = ub_name

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid (sizes are in global, un-sharded units)
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError(
                    "Splitting QuantizedTensor into multiple params is not supported"
                )

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # Attach empty placeholder tensors under the bias names.
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=on_meta_device)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def reset_parameters(self, defer_init=False):
        """(Re)initialize parameters and tag them with tensor-parallel metadata.

        Initialization itself is delegated to the base class. Unless
        ``defer_init`` is set, every weight is then marked as tensor-parallel
        along dim 1 for row-parallel layers (dim 0 otherwise). Biases get a
        ``sequence_parallel`` attribute in row-parallel mode, or are marked
        tensor-parallel along dim 0 in column-parallel mode.
        """
        super().reset_parameters(defer_init=defer_init)
        if defer_init:
            # Nothing further to tag while initialization is deferred.
            return

        is_row_parallel = self.parallel_mode == "row"
        weight_split_dim = 1 if is_row_parallel else 0
        for param_name in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, param_name),
                is_parallel=True,
                dim=weight_split_dim,
                stride=1,
            )

        if not self.use_bias:
            return
        for param_name in self.bias_names:
            if is_row_parallel:
                # Row-parallel bias is not sharded; only record the SP flag.
                getattr(self, param_name).sequence_parallel = self.sequence_parallel
            elif self.parallel_mode == "column":
                set_tensor_model_parallel_attributes(getattr(self, param_name), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
        fp8_grad: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : bool, default = False
                    when FP8 is enabled, also select the FP8 output quantizer and
                    forward this flag to the linear autograd function.
        fp8_grad : bool, default = False
                  when FP8 is enabled and autograd is active, also select the FP8
                  quantizer for the gradient with respect to the input.

        Returns
        -------
        torch.Tensor, or a tuple ``(output, bias)`` when ``return_bias`` is set.
        """
        # While an FP8 CUDA graph is being captured, the weight-update decision
        # comes from a tensor held by the global FP8 state, and the microbatch is
        # never treated as the first one.
        skip_fp8_weight_update = (
            FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
            if FP8GlobalStateManager.fp8_graph_capturing()
            else None
        )
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:
            # Gather the (possibly split) weight params into one tensor.
            weight_params = [getattr(self, name) for name in self.weight_names]
            if any(isinstance(param, QuantizedTensor) for param in weight_params):
                if not self.fp8:
                    # Quantized weights are used in high precision outside FP8 runs.
                    weight_params = [param.dequantize() for param in weight_params]
                elif len(weight_params) != 1:
                    raise RuntimeError(
                        "Splitting QuantizedTensor into multiple params is not supported"
                    )
            weight_tensor = noop_cat(weight_params)
            bias_tensor = (
                noop_cat([getattr(self, name) for name in self.bias_names])
                if self.use_bias
                else None
            )

            (
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_output_quantizer,
                grad_input_quantizer,
            ) = self._get_quantizers(fp8_output, fp8_grad)

            # The quantization recipe may have changed since the weight was
            # quantized, so attach the current weight quantizer.
            if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
                weight_tensor._quantizer = weight_quantizer

            grad_enabled = torch.is_grad_enabled()
            # The bias is fused into the GEMM unless it is returned to the caller
            # or must be added after the row-parallel collectives.
            fused_bias = (
                bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None
            )
            # Positional order must match _Linear.forward exactly.
            linear_args = (
                weight_tensor,
                inp,
                fused_bias,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                input_quantizer,
                weight_quantizer,
                output_quantizer,
                grad_output_quantizer,
                grad_input_quantizer,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                grad_enabled,
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
                self,
                skip_fp8_weight_update,
            )
            if grad_enabled:
                out = _Linear.apply(*linear_args)
            else:
                # Without autograd, call forward directly with no ctx object.
                out = _Linear.forward(None, *linear_args)

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out

    def _get_quantizers(self, fp8_output, fp8_grad):
        """Select the FP8 quantizers for one forward call.

        Returns a 5-element sequence ``(input, weight, output, grad_output,
        grad_input)``; every entry is ``None`` when FP8 is disabled. The
        output quantizer is only selected when ``fp8_output`` is set. The
        backward quantizers are only selected while autograd is enabled, and
        the grad-input one additionally requires ``fp8_grad``.

        Side effect: sets the ``internal`` flag on the input (False), weight
        (True) and grad-output (True) quantizers.
        """
        if not self.fp8:
            return [None] * 5

        fwd_quantizers = self.quantizers["scaling_fwd"]
        input_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = False
        weight_quantizer = fwd_quantizers[tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        output_quantizer = (
            fwd_quantizers[tex.FP8FwdTensors.GEMM1_OUTPUT] if fp8_output else None
        )

        grad_output_quantizer = None
        grad_input_quantizer = None
        if torch.is_grad_enabled():
            bwd_quantizers = self.quantizers["scaling_bwd"]
            grad_output_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
                grad_input_quantizer = bwd_quantizers[tex.FP8BwdTensors.GRAD_INPUT1]

        return (
            input_quantizer,
            weight_quantizer,
            output_quantizer,
            grad_output_quantizer,
            grad_input_quantizer,
        )