linear.py 34.9 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
6
import warnings
7
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

import torch
from torch.nn.parameter import Parameter

import transformer_engine_extensions as tex

from .base import (
    get_workspace,
    _prepare_backward,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
23
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
24
25
26
27
from ..utils import (
    divide,
    get_default_init_method,
    cast_if_needed,
28
    assert_dim_for_fp8_exec,
29
    clear_tensor_data,
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    initialize_affine_weight_gpu,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    gather_along_last_dim,
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
47
from ..jit import no_torch_dynamo
48

49
50
from ..float8_tensor import Float8Tensor

51
52
53
54
55
56
57
58
59
60
61
62

__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.

    ``forward`` computes ``y = x W^T (+ b)`` with the TE GEMM extensions, either
    in FP8 (``fp8_gemm``) or in ``activation_dtype`` (``gemm``), and then applies
    the tensor/sequence-parallel collectives selected by ``parallel_mode``.
    ``backward`` computes the input gradient (dgrad) and the weight gradient
    (wgrad), overlapping communication with those GEMMs where possible.
    """

    @staticmethod
    def _ub_buffer_name(ub_name: Optional[str], suffix: str) -> str:
        """Return the userbuffer key ``<ub_name>_<suffix>`` ("proj" if ub_name is None)."""
        # "proj" is the prefix that used to be hard-coded in the non-FP8 path,
        # so modules created without a ub_name keep their previous behavior.
        prefix = "proj" if ub_name is None else ub_name
        return f"{prefix}_{suffix}"

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Union[Float8Tensor, None],
        weight_t_fp8: Union[Float8Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        primary_weights_in_fp8: bool,
        ub_split_rs: bool,
        ub_split_ag: bool,
        ub_atomic_gemm_rs: bool,
        ub_atomic_gemm_ag: bool,
        ub_name: str,
    ) -> torch.Tensor:
        """Run the forward GEMM and the output-side parallel collectives.

        Args:
            weight: weight whose last dim must equal ``inp``'s last dim and whose
                first dim sizes the output; a ``Float8Tensor`` when
                ``primary_weights_in_fp8`` is set.
            weight_fp8, weight_t_fp8: FP8 storage for the weight and its
                transpose; refilled when ``is_first_microbatch`` is ``None`` or
                ``True`` (and ``primary_weights_in_fp8`` is off).
            inp: input, flattened to 2D ``(-1, in_features)`` for the GEMM.
            is_first_microbatch: ``None``/``True`` refresh the FP8 weight cache,
                ``False`` reuses it.
            fp8_calibration: when not running in FP8, record the amax of the
                input and the weight into ``fp8_meta["scaling_fwd"]``.
            is_grad_enabled: when ``False`` nothing is saved for backward and
                transposes needed only by backward are skipped.
            ub_split_rs, ub_atomic_gemm_rs: overlap the row-parallel
                reduce-scatter with the GEMM through userbuffers (turned off when
                the TP world size is 1; atomic variant requires FP8).
            ub_split_ag, ub_atomic_gemm_ag: stored for ``backward`` (dgrad
                all-gather overlap).
            ub_name: userbuffer prefix (``<ub_name>_fprop`` / ``<ub_name>_dgrad``);
                ``"proj"`` is used when it is ``None``.

        Returns:
            Tensor viewed as ``(-1, *inp.shape[1:-1], out_features)``; the first
            dim differs from ``inp``'s when sequence parallelism scatters it.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Userbuffer RS overlap is pointless on a single TP rank.
        if ub_split_rs or ub_atomic_gemm_rs:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1:
                ub_split_rs = False
                ub_atomic_gemm_rs = False
        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
            assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_no_fp8 = inputmat

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_grad_enabled:
                    # The FP8 transpose is only needed by the FP8 wgrad GEMM.
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if primary_weights_in_fp8:
                # Weight is already in FP8
                weight.reset_fp8_meta_scale_inv()
                weight_fp8 = weight
                weight_t_fp8 = None
                if is_grad_enabled:
                    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)

            elif update_fp8_weights:
                # Need to cast weights to FP8
                weight_fp8 = Float8Tensor(
                    data=weight_fp8._data,
                    fp8_meta=fp8_meta,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                )
                if is_grad_enabled:
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8._data,
                        transpose_out=weight_t_fp8._data,
                    )
                else:
                    weight_fp8._data = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )
                    weight_t_fp8 = None

            # Defaults: GEMM writes a regular activation_dtype output.
            proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                None, None, None, activation_dtype)
            if ub_split_rs or ub_atomic_gemm_rs:
                ub_obj_projout = get_ub(_Linear._ub_buffer_name(ub_name, "fprop"))
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

                if ub_obj_projout.is_fp8_ubuf():
                    # The userbuffer holds FP8 data: have the GEMM emit FP8 output.
                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    proj_out_tetype = fp8_dtype_forward
                    proj_out_pttype = torch.uint8
                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight.size(0)
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
            _ = fp8_gemm(
                weight_fp8._data,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                proj_out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None,
                extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None,
                out_index=proj_out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=proj_out_tetype,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(inputmat_total).float()
                # amax of weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(weight).float()

            if ub_split_rs:
                # Same userbuffer naming as the FP8 path (previously hard-coded
                # to "proj_fprop", which ignored ub_name).
                ub_obj_projout = get_ub(_Linear._ub_buffer_name(ub_name, "fprop"))
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight.size(0)
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_projout if ub_split_rs else None,
                extra_output_tensor=rs_out if ub_split_rs else None,
            )

        if is_grad_enabled:
            # FP8 wgrad consumes the FP8 input transpose; otherwise the
            # high-precision input is kept instead.
            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
            if fp8:
                assert hasattr(weight_t_fp8, "_data"), \
                       "_data attr doesn't exist (before save for bwd)"
            ctx.save_for_backward(
                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                inputmat_t if weight.requires_grad and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8 else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_split_ag = ub_split_ag
            ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if ub_split_rs or ub_atomic_gemm_rs:
            # The reduce-scatter already happened inside the overlapped GEMM.
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute dgrad and wgrad, overlapping collectives with the GEMMs.

        Returns one entry per ``forward`` input (24 in total): the weight
        gradient, the input gradient (reshaped to the input's shape, or
        ``None`` when the input does not require grad), the bias gradient, and
        ``None`` for every other argument.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            if weight_t_fp8 is not None:
                assert hasattr(weight_t_fp8, "_data"), \
                       "_data attr doesn't exist (after restore in bwd)"

            # Userbuffer AG overlap for dgrad; disabled on a single TP rank.
            if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_split_ag = False
                    ctx.ub_atomic_gemm_ag = False
            if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                ctx.ub_obj_gradout = get_ub(_Linear._ub_buffer_name(ctx.ub_name, "dgrad"))

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    inputmat_t_total, handle = gather_along_last_dim(
                        inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
                    )
                else:
                    inputmat_total, handle = gather_along_first_dim(
                        inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                    )
            else:
                inputmat_t_total = inputmat_t
                inputmat_total = inputmat
                handle = None

            # With fused accumulation, only the first microbatch overwrites main_grad.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
            if ctx.requires_dgrad:
                if ctx.fp8:
                    dgrad, _ = fp8_gemm(
                        weight_t_fp8._data,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        ctx.activation_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None,
                    )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    # The input gather must finish before wgrad uses it and
                    # before `handle` is reused for the reduce-scatter.
                    if handle is not None:
                        handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        wgrad, _ = fp8_gemm(
                            inputmat_t_total,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                        clear_tensor_data(inputmat_t_total)
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                        clear_tensor_data(inputmat_total)
                else:
                    # WGRAD (also yields the bias gradient in this path)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )
                    clear_tensor_data(inputmat_total)

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
                weight.grad_added_to_main_grad = True
                # The real gradient already lives in weight.main_grad; a
                # placeholder tensor is returned, presumably so autograd still
                # reports a gradient for the weight (TODO confirm with mcore).
                wgrad = torch.empty(weight.main_grad.shape,
                                    dtype=weight.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False
                                    )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        return (
            wgrad,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            grad_bias,  # bias
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # primary_weights_in_fp8
            None,  # ub_split_rs
            None,  # ub_split_ag
            None,  # ub_atomic_gemm_rs
            None,  # ub_atomic_gemm_ag
            None,  # ub_name
        )


class Linear(TransformerEngineBaseModule):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
cyanguwa's avatar
cyanguwa committed
529
530
531
532
533
534
535
536
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      if a tuple of strings or a dict of strings to integers is provided,
                      the weight and bias parameters of the module are exposed as `N` separate
                      `torch.nn.parameter.Parameter`s each, split along the first dimension,
                      where `N` is the length of the argument and the strings contained are the
                      names of the split parameters. In the case of a tuple, each parameter
                      has the same shape. In the case of a dict, the values give the
                      `out_features` for each projection.
537
538
539
540
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'Column', 'Row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
572
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
590
        params_dtype: Optional[torch.dtype] = None,
591
        parallel_mode: Optional[str] = None,
cyanguwa's avatar
cyanguwa committed
592
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
593
        device: Union[torch.device, str] = "cuda",
594
595
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
596
597
        ub_atomic_gemm_rs: bool = False,
        ub_atomic_gemm_ag: bool = False,
598
        ub_name: Optional[str] = None,
599
600
    ) -> None:
        super().__init__()
601
602

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
603
604
605
606
607
608
609
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.parameters_split = parameters_split
610
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
611
612
        self.ub_split_rs = ub_split_rs
        self.ub_split_ag = ub_split_ag
613
614
        self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
        self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
615
616
617
        if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name
618

619
        if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
620
621
622
623
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

624
625
626
627
628
        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
            warnings.warn(
                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
            )

629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

653
        temp_weight = torch.empty(
654
            self.out_features, self.in_features,
655
            device=device, dtype=params_dtype)
656

657
        # TODO(ksivaman): This functionality works with FP8 outside TE.
658
        initialize_affine_weight_gpu(
659
            temp_weight,
660
661
662
663
664
665
            init_method,
            get_rng_state_tracker,
            partition_dim=1 if self.parallel_mode == "row" else 0,
            stride=1,
        )

666
667
668
669
670
671
672
673
674
675
676
677
        if self.primary_weights_in_fp8:
            self.init_fp8_metadata()
            self.fp8_meta["update_amax_and_scale_fwd"] = True

            self.weight_tensor = Float8Tensor.to_float8(
                temp_weight,
                fp8_meta=self.fp8_meta,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )
        else:
            self.weight_tensor = temp_weight

678
        if self.use_bias:
679
            self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
680
        else:
681
            self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
682

683
684
        with torch.no_grad():
            self.bias_tensor.zero_()
685

686
        if parameters_split is None:
cyanguwa's avatar
cyanguwa committed
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
            parameters_split = {"": self.out_features}
        elif isinstance(parameters_split, tuple):
            assert (
                self.out_features % len(parameters_split) == 0
            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
            split_size = self.out_features // len(parameters_split)
            parameters_split = {key: split_size for key in parameters_split}
        elif isinstance(parameters_split, dict):
            overall_split_size = sum(parameters_split.values())
            assert(
                self.out_features == overall_split_size
            ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\
               f"to out features (={self.out_features})"
        else:
            assert False, "Type of 'parameters_split' is not None, tuple or dict"
        self.updated_parameters_split = parameters_split
703

704
705
        self.weight_names = []
        self.bias_names = []
706

cyanguwa's avatar
cyanguwa committed
707
708
        slice_begin = 0
        for pname, slice_size in parameters_split.items():
709
710
            wname = pname + "weight"
            bname = pname + "bias"
711

cyanguwa's avatar
cyanguwa committed
712
713
            slice_end = slice_begin + slice_size

714
715
716
717
718
719
720
721
722
723
724
            # TODO(ksivaman): Add indexing op to torch dispatcher for float8
            if self.primary_weights_in_fp8:
                assert len(parameters_split) == 1, ("Slicing operation is not "
                                                    "supported in Float8Tensor "
                                                    "class!")
                self.register_parameter(wname, Parameter(self.weight_tensor))
            else:

                self.register_parameter(
                    wname, Parameter(self.weight_tensor[slice_begin:slice_end])
                )
725

726
727
728
729
730
731
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, wname),
                is_parallel=True,
                dim=1 if parallel_mode == "row" else 0,
                stride=1,
            )
732

733
            if self.use_bias:
734
                self.register_parameter(
cyanguwa's avatar
cyanguwa committed
735
                    bname, Parameter(self.bias_tensor[slice_begin:slice_end])
736
                )
737
738
                if parallel_mode == "row":
                    setattr(getattr(self, bname), "sequence_parallel", sequence_parallel)
739
            else:
740
                setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device))
741

742
743
            if parallel_mode == "column":
                set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
744

745
746
            self.weight_names.append(wname)
            self.bias_names.append(bname)
747

cyanguwa's avatar
cyanguwa committed
748
749
            slice_begin = slice_end

750
751
752
753
754
755
756
757
758
        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

759
760
761
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Optional[Float8Tensor]]:
        """
        Return the FP8 weight placeholders ``[weight_fp8, weight_t_fp8]``
        used by the GEMM.

        * If FP8 execution is disabled, or the primary weights are already
          stored in FP8, no scratch buffers are needed and ``[None, None]``
          is returned.
        * If ``is_first_microbatch`` is ``None``, fresh placeholders are
          fetched from ``get_fp8_weights_empty_tensors`` on every call.
        * Otherwise the persistent placeholders ``self.weight1_fp8`` and
          ``self.weight1_t_fp8`` are returned so the FP8 weights can be
          reused across microbatches. These attributes are expected to have
          been created by the ``set_fp8_weights`` method.

        Parameters
        ----------
        is_first_microbatch : {True, False, None}
            Microbatch flag forwarded from ``forward``. ``None`` disables
            placeholder caching.

        Returns
        -------
        List[Optional[Float8Tensor]]
            Two entries ``[weight_fp8, weight_t_fp8]``. Both are ``None``
            when no FP8 scratch space is required.
        """
        # No FP8 GEMM, or the weight parameter itself is already a
        # Float8Tensor: nothing to cast into.
        if not self.fp8 or self.primary_weights_in_fp8:
            return [None, None]

        if is_first_microbatch is None:
            # No caching requested: hand out new placeholders for this
            # fwd/bwd pass.
            return self.get_fp8_weights_empty_tensors(is_first_microbatch)

        # Persistent placeholders (created in `set_fp8_weights`) are reused
        # across microbatches.
        return [self.weight1_fp8, self.weight1_t_fp8]

783
    @no_torch_dynamo
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
            The GEMM output. When ``self.return_bias`` is set, also returns the
            bias tensor cast to ``self.activation_dtype``.
        """
        with self.prepare_forward(inp, is_first_microbatch) as inp:
            assert self.fp8 or not self.primary_weights_in_fp8, (
                "Need to run inside fp8_autocast region when weights are stored in FP8."
            )

            # Grad mode is sampled once and reused below: it selects the
            # parameter view, the autograd entry point, and is also passed
            # to _Linear as its own flag.
            grad_enabled = torch.is_grad_enabled()

            # Choose the bias/weight tensors fed to the GEMM.
            if self.parameters_split is None:
                # Unsplit layer: the registered `bias`/`weight` parameters.
                bias_tensor = self.bias
                weight_tensor = self.weight
            elif not grad_enabled:
                # Split layer without autograd: use the full backing buffers
                # directly.
                bias_tensor = self.bias_tensor
                weight_tensor = self.weight_tensor
            else:
                # Split layer with autograd: re-join the per-split parameters
                # through `noop_cat` (bias first, then weight).
                bias_tensor = self.noop_cat(
                    "bias_tensor", self.bias_names, self.updated_parameters_split
                )
                weight_tensor = self.noop_cat(
                    "weight_tensor", self.weight_names, self.updated_parameters_split
                )

            # FP8 weight placeholders for the GEMM (both None when unused).
            weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
                is_first_microbatch
            )

            # Positional arguments in exactly the order _Linear.forward
            # expects them (after `ctx`).
            gemm_args = (
                weight_tensor,
                weight1_fp8,
                weight1_t_fp8,
                inp,
                bias_tensor,
                # Bias is fused into the GEMM only when it is not added
                # separately after the TP collectives.
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                grad_enabled,
                self.primary_weights_in_fp8,
                self.ub_split_rs,
                self.ub_split_ag,
                self.ub_atomic_gemm_rs,
                self.ub_atomic_gemm_ag,
                self.ub_name,
            )

            if grad_enabled:
                out = _Linear.apply(*gemm_args)
            else:
                # Bypass autograd.Function.apply; `None` fills the `ctx` slot
                # of _Linear.forward.
                out = _Linear.forward(None, *gemm_args)

        # Row-parallel case: bias was not fused into the GEMM, so add it here.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if not self.return_bias:
            return out
        return out, cast_if_needed(bias_tensor, self.activation_dtype)