linear.py 37.7 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Linear API"""
6
import warnings
7
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
8
9
10
11
12
13
14
15
16
17
18
19
20
21

import torch

import transformer_engine_extensions as tex

from .base import (
    get_workspace,
    _prepare_backward,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
22
from ._common import _noop_cat
23
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
24
25
26
from ..utils import (
    divide,
    cast_if_needed,
27
    assert_dim_for_fp8_exec,
28
    clear_tensor_data,
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
44
from ..jit import no_torch_dynamo
45

46
47
from ..float8_tensor import Float8Tensor

48
49
50
51
52
53
54
55
56
57
58
# Public API of this module: only the ``Linear`` module is exported; the
# ``_Linear`` autograd function below is an internal implementation detail.
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.

    Autograd function computing ``inp @ weight^T (+ bias)`` with optional FP8
    execution, tensor-/sequence-parallel communication, Userbuffers (UB)
    communication/GEMM overlap and CPU offloading of saved tensors.
    """

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Union[Float8Tensor, None],
        weight_t_fp8: Union[Float8Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        primary_weights_in_fp8: bool,
        ub_split_rs: bool,
        ub_split_ag: bool,
        ub_atomic_gemm_rs: bool,
        ub_atomic_gemm_ag: bool,
        ub_name: str
    ) -> torch.Tensor:
        """Compute the forward GEMM and any parallel-output communication.

        Flattens ``inp`` to ``[N, in_features]``, optionally casts input and
        weight to FP8, gathers the input for column-parallel sequence
        parallelism, runs the GEMM (FP8 or ``activation_dtype``), and then
        reduce-scatters or all-reduces the output for row-parallel modes.
        When ``is_grad_enabled`` is set, the tensors required by
        :meth:`backward` are saved on ``ctx``.

        Returns:
            Output viewed as ``[-1, *inp.shape[1:-1], out_features]``; the
            first dimension may differ from ``inp`` under sequence
            parallelism.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        # Collapse all leading dimensions: [*, in_features] -> [N, in_features]
        inputmat = inp.view((-1, in_features))
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

        # FP8 weight casts are refreshed on the first microbatch or on every
        # call when microbatching is not used (is_first_microbatch is None);
        # presumably later microbatches reuse the cached weight_fp8.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        if ub_split_rs or ub_atomic_gemm_rs:
            tp_world_size = get_distributed_world_size(tp_group)
            # Disable UB reduce-scatter overlap for a single-rank TP group.
            if tp_world_size == 1:
                ub_split_rs = False
                ub_atomic_gemm_rs = False
        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
            assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."

        # Cast input to expected dtype
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_t = None
        # High-precision copy, saved for backward when wgrad is not run in FP8.
        inputmat_no_fp8 = inputmat
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if (
                not fp8_meta["recipe"].override_linear_precision.wgrad
                and is_grad_enabled
                and weight.requires_grad
                and not sequence_parallel
            ):
                # FP8 input for forward, FP8 input transpose for backward wgrad
                inputmat, inputmat_t = fp8_cast_transpose_fused(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                # FP8 input for forward
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        if fp8:
            # FP32 activations get a BF16 bias on the FP8 path (presumably the
            # FP8 GEMM bias epilogue does not take FP32 -- TODO confirm).
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if primary_weights_in_fp8:
                # Weight is already in FP8
                weight.reset_fp8_meta_scale_inv()
                weight_fp8 = weight
                # The transpose is produced lazily in backward.
                weight_t_fp8 = None
            elif update_fp8_weights:
                # Need to cast weights to FP8
                weight_fp8 = Float8Tensor(
                    data=weight_fp8._data,
                    fp8_meta=fp8_meta,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                )
                if is_grad_enabled:
                    # Write both the cast and its transpose into the caller's
                    # preallocated buffers.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8._data,
                        transpose_out=weight_t_fp8._data,
                    )
                else:
                    weight_fp8._data = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )
                    weight_t_fp8 = None

            # Output type defaults to activation_dtype; switched to FP8 below
            # only when the UB buffer itself is FP8.
            proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                None, None, None, activation_dtype)
            if ub_split_rs or ub_atomic_gemm_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

                if ub_obj_projout.is_fp8_ubuf():
                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    proj_out_tetype = fp8_dtype_forward
                    proj_out_pttype = torch.uint8
                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight.size(0)
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            # Split-pipelined RS takes precedence if both flags are set.
            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
            _ = fp8_gemm(
                weight_fp8._data,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                proj_out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None,
                extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None,
                out_index=proj_out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=proj_out_tetype,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # Record amax = max(|min|, |max|) into the current history slot.
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.max(-amin, amax).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.max(-amin, amax).float()

            if ub_split_rs:
                # Use this layer's own userbuffer, matching the FP8 path above
                # and the "_dgrad" buffer lookup in backward.
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight.size(0)
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_projout if ub_split_rs else None,
                extra_output_tensor=rs_out if ub_split_rs else None,
            )

        if is_grad_enabled:
            # Pick which input representation backward's wgrad needs:
            # FP8 transpose, FP8 input (transposed later), or high precision.
            saved_inputmat = None
            saved_inputmat_t = None
            if weight.requires_grad:
                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
                    if inputmat_t is None:
                        saved_inputmat = inputmat
                    else:
                        saved_inputmat_t = inputmat_t
                        if cpu_offloading:
                            saved_inputmat_t.activation_offloading = True
                else:
                    saved_inputmat = inputmat_no_fp8

                if cpu_offloading:
                    if fuse_wgrad_accumulation:
                        weight.main_grad.weight_offloading = True
                    # weight_t_fp8 is None for primary FP8 weights (and when
                    # the FP8 weight was not refreshed this call with no
                    # transpose), so only tag it when it exists.
                    if fp8 and weight_t_fp8 is not None:
                        weight_t_fp8.weight_offloading = True
                    weight.weight_offloading = True

                    if saved_inputmat is not None:
                        saved_inputmat.activation_offloading = True

            ctx.save_for_backward(
                saved_inputmat,
                saved_inputmat_t,
                weight,
                # main_grad is saved explicitly so backward can re-attach it
                # to the (possibly offloaded) weight.
                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
                weight_t_fp8 if fp8 else None,
                # Snapshot of the forward inverse scales for backward's GEMMs
                # (presumably guards against later scale updates -- confirm).
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_split_ag = ub_split_ag
            ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if ub_split_rs or ub_atomic_gemm_rs:
            # The UB-overlapped GEMM already wrote the reduce-scattered result.
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute input, weight and bias gradients.

        Returns one entry per :meth:`forward` argument (25 in total): the
        weight gradient, the input gradient (``dgrad``, only if the input
        required grad) and the bias gradient in the ``weight``, ``inp`` and
        ``bias`` slots; ``None`` for every other argument.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):
            (
                inputmat,
                inputmat_t,
                weight,
                main_grad,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                # Re-wrap the restored weight and re-attach its main_grad,
                # which forward saved as a separate tensor.
                weight = torch.nn.Parameter(weight, False)
                weight.main_grad = main_grad

            # Primary weights are in FP8.
            if ctx.fp8 and weight_t_fp8 is None:
                weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)

            if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                # Disable UB all-gather overlap for a single-rank TP group.
                if tp_world_size == 1:
                    ctx.ub_split_ag = False
                    ctx.ub_atomic_gemm_ag = False
            if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            inputmat_total = None
            inputmat_t_total = None
            handle = None
            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                # Async only when dgrad work exists to overlap with; the
                # handle is waited on before the dgrad reduce-scatter.
                inputmat_total, handle = gather_along_first_dim(
                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                )
            else:
                inputmat_total = inputmat
                inputmat_t_total = inputmat_t

            # Overwrite main_grad on the first microbatch, accumulate otherwise.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

            # Split-pipelined AG is overridden by atomic AG if both are set.
            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
            if ctx.requires_dgrad:
                if ctx.fp8:
                    dgrad, _ = fp8_gemm(
                        weight_t_fp8._data,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        ctx.activation_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None,
                    )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    if handle is not None:
                        handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        if inputmat_t_total is None:
                            # The saved input was cast with the forward FP8
                            # dtype, so transpose it as such.
                            inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_forward)
                        wgrad, _ = fp8_gemm(
                            inputmat_t_total,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (bias gradient is produced by the same GEMM)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                # Deallocate input tensor
                clear_tensor_data(inputmat_total)
                clear_tensor_data(inputmat_t_total)

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
                # The real gradient lives in main_grad; hand autograd a
                # placeholder of matching shape instead.
                weight.grad_added_to_main_grad = True
                if getattr(weight, 'zero_out_wgrad', False):
                    wgrad = torch.zeros(weight.main_grad.shape,
                                        dtype=weight.dtype,
                                        device=torch.cuda.current_device(),
                                        requires_grad=False
                                       )
                else:
                    wgrad = torch.empty(weight.main_grad.shape,
                                        dtype=weight.dtype,
                                        device=torch.cuda.current_device(),
                                        requires_grad=False
                                       )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        # One entry per forward() argument, in order.
        return (
            wgrad,                                                   # weight
            None,                                                    # weight_fp8
            None,                                                    # weight_t_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            grad_bias,                                               # bias
            None,                                                    # use_bias
            None,                                                    # is_first_microbatch
            None,                                                    # fp8
            None,                                                    # fp8_calibration
            None,                                                    # fp8_meta
            None,                                                    # fuse_wgrad_accumulation
            None,                                                    # cpu_offloading
            None,                                                    # tp_group
            None,                                                    # tp_size
            None,                                                    # sequence_parallel
            None,                                                    # tensor_parallel
            None,                                                    # activation_dtype
            None,                                                    # parallel_mode
            None,                                                    # is_grad_enabled
            None,                                                    # primary_weights_in_fp8
            None,                                                    # ub_split_rs
            None,                                                    # ub_split_ag
            None,                                                    # ub_atomic_gemm_rs
            None,                                                    # ub_atomic_gemm_ag
            None,                                                    # ub_name
        )


class Linear(TransformerEngineBaseModule):
543
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`
544
545
546
547
548
549
550
551
552
553
554
555
556
557

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
cyanguwa's avatar
cyanguwa committed
558
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
559
560
561
562
563
564
565
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
566
567
568
569
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'Column', 'Row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
601
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
602
603
604
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
605

606
607
608
609
610
611
612
613
614
615
616
617
618
619
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
620
        params_dtype: Optional[torch.dtype] = None,
621
        parallel_mode: Optional[str] = None,
cyanguwa's avatar
cyanguwa committed
622
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
623
        device: Union[torch.device, str] = "cuda",
624
625
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
626
627
        ub_atomic_gemm_rs: bool = False,
        ub_atomic_gemm_ag: bool = False,
628
        ub_name: Optional[str] = None,
629
630
    ) -> None:
        super().__init__()
631
632

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
633
634
635
636
637
638
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
639
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
640
641
        self.ub_split_rs = ub_split_rs
        self.ub_split_ag = ub_split_ag
642
643
        self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
        self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
644
645
646
        if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name
647
648
649
650
        self.get_rng_state_tracker = get_rng_state_tracker
        if device == 'meta':
            assert parameters_split is None, ("Cannot split module parameters "
                                              "on 'meta' device.")
651

652
        if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
653
654
655
656
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

657
658
659
660
661
        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
            warnings.warn(
                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
            )

662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

683
        self.weight_tensor = torch.empty(
684
            self.out_features, self.in_features,
685
            device=device, dtype=params_dtype)
686
687

        if self.use_bias:
688
            self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
689
        else:
690
            self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
691

692
693
694
695
        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
696
        if parameters_split is None:
697
698
699
700
701
702
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
cyanguwa's avatar
cyanguwa committed
703
        elif isinstance(parameters_split, dict):
704
705
706
707
708
709
710
711
712
713
714
715
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
cyanguwa's avatar
cyanguwa committed
716
        else:
717
            raise TypeError("Invalid configuration for parameters split")
718

719
720
721
722
723
724
        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )
725

726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct parameters from weight and bias buffers
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and self.primary_weights_in_fp8:
                raise RuntimeError(
                    "Splitting Float8Tensor into multiple params "
                    "is not supported"
                )
750

751
752
753
754
755
            # Construct weight parameter
            weight = self.weight_tensor
            if is_subview:
                weight = weight[split_start:split_end]
            weight = torch.nn.Parameter(weight)
756
757
758
759
            self.register_parameter(self.weight_names[i], weight,
                                    init_fn=init_method,
                                    get_rng_state_tracker=get_rng_state_tracker,
                                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
cyanguwa's avatar
cyanguwa committed
760

761
762
763
764
765
766
767
768
769
            # Construct bias parameter if needed
            if self.use_bias:
                bias = self.bias_tensor
                if is_subview:
                    bias = bias[split_start:split_end]
                bias = torch.nn.Parameter(bias)
                self.register_parameter(self.bias_names[i], bias)
                if parallel_mode == "row":
                    bias.sequence_parallel = sequence_parallel
770
            else:
771
772
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, self.bias_names[i], bias)
773

774
            # Configure tensor parallelism
775
            set_tensor_model_parallel_attributes(
776
                tensor=weight,
777
778
779
780
781
                is_parallel=True,
                dim=1 if parallel_mode == "row" else 0,
                stride=1,
            )
            if parallel_mode == "column":
782
                set_tensor_model_parallel_attributes(bias, True, 0, 1)
783

784
785
786
787
788
            # Concatenated tensors are not needed if not splitting
            # into multiple parameters
            if not is_subview:
                del self.weight_tensor
                del self.bias_tensor
cyanguwa's avatar
cyanguwa committed
789

790
791
792
793
794
795
        if self.primary_weights_in_fp8:
            self.init_fp8_metadata()
            self.fp8_meta["update_amax_and_scale_fwd"] = True

        self.reset_parameters(defer_init=(device == 'meta'))

796
797
798
799
800
801
802
803
804
        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

805
806
807
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Optional[Float8Tensor]]:
        """
        Return the FP8 weight placeholders ``[weight_fp8, weight_t_fp8]``
        to hand to the linear GEMM.

        Parameters
        ----------
        is_first_microbatch : {True, False, None}
            When ``None``, fresh placeholders are requested from
            ``get_fp8_weights_empty_tensors`` for this fwd/bwd pass.
            Otherwise the persistent ``self.weight1_fp8`` /
            ``self.weight1_t_fp8`` attributes are returned. Those should
            already exist (created in the ``set_fp8_weights`` method).

        Returns
        -------
        List[Optional[Float8Tensor]]
            ``[None, None]`` when FP8 execution is off, or when the primary
            weights are already stored in FP8. Otherwise, the two
            placeholders described above.
        """
        # No scratch FP8 copies are needed when not running in FP8, or when
        # the weights themselves are already FP8 tensors.
        if not self.fp8 or self.primary_weights_in_fp8:
            return [None, None]

        if is_first_microbatch is None:
            # Return empty weight placeholders for each fwd/bwd pass.
            return self.get_fp8_weights_empty_tensors(is_first_microbatch)

        # These persistent weight placeholders should've been created in the
        # `set_fp8_weights` method; reusing them lets FP8 weights be cached
        # across microbatches.
        return [self.weight1_fp8, self.weight1_t_fp8]

829
    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Run the (optionally tensor-parallel, optionally FP8) linear layer.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
            The layer output. When ``self.return_bias`` is set, the output is
            paired with the bias cast to ``self.activation_dtype``.
        """
        with self.prepare_forward(inp, is_first_microbatch) as inp:
            # FP8-stored weights can only be consumed inside fp8_autocast.
            assert self.fp8 or not self.primary_weights_in_fp8, (
                "Need to run inside fp8_autocast region when weights are stored in FP8."
            )

            # Pick the weight/bias tensors fed to the GEMM.
            if len(self.parameter_split_sizes) == 1:
                # A single parameter: use it directly.
                weight_tensor = getattr(self, self.weight_names[0])
                bias_tensor = getattr(self, self.bias_names[0])
            elif not torch.is_grad_enabled():
                # No autograd tracking: use the concatenated buffers that the
                # split parameters were constructed as views of.
                weight_tensor = self.weight_tensor
                bias_tensor = self.bias_tensor
            else:
                # Autograd tracking: rebuild the full tensors from the split
                # parameters through `_noop_cat`.
                weight_tensor = _noop_cat(
                    [getattr(self, w_name) for w_name in self.weight_names],
                    self.weight_tensor,
                )
                bias_tensor = (
                    _noop_cat(
                        [getattr(self, b_name) for b_name in self.bias_names],
                        self.bias_tensor,
                    )
                    if self.use_bias
                    else getattr(self, self.bias_names[0])  # Unused placeholder
                )

            # FP8 weight placeholders for the GEMM (or None when not needed).
            weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
                is_first_microbatch
            )

            from ..cpu_offload import CPUOffloadEnabled

            grad_enabled = torch.is_grad_enabled()
            # Positional arguments for `_Linear.forward` (after `ctx`); the
            # order must match that signature exactly.
            linear_args = (
                weight_tensor,
                weight1_fp8,
                weight1_t_fp8,
                inp,
                bias_tensor,
                # Bias is fused into the GEMM only when it is not added
                # separately after the tensor-parallel collectives.
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                CPUOffloadEnabled,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                grad_enabled,
                self.primary_weights_in_fp8,
                self.ub_split_rs,
                self.ub_split_ag,
                self.ub_atomic_gemm_rs,
                self.ub_atomic_gemm_ag,
                self.ub_name,
            )
            if grad_enabled:
                # Record the op in the autograd graph.
                out = _Linear.apply(*linear_args)
            else:
                # Inference path: call the forward directly with no ctx.
                out = _Linear.forward(None, *linear_args)

        # Row-parallel bias must be applied after the TP collectives, so it
        # was not fused into the GEMM above.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        return (
            (out, cast_if_needed(bias_tensor, self.activation_dtype))
            if self.return_bias
            else out
        )