# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch

import transformer_engine_extensions as tex

from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import _noop_cat
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
    divide,
    cast_if_needed,
    assert_dim_for_fp8_exec,
    clear_tensor_data,
    init_method_constant,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor

# Public surface of this module: only the ``Linear`` module is exported;
# ``_Linear`` (the autograd function below) stays private.
__all__ = [
    "Linear",
]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.

    ``forward`` computes ``inp @ weight^T (+ bias)`` either with FP8 GEMMs
    (``fp8=True``, via ``fp8_gemm``) or in ``activation_dtype`` (via ``gemm``).
    It can optionally:

    * gather the input (column-parallel + sequence-parallel);
    * reduce-scatter or all-reduce the output (row-parallel);
    * overlap communication with the GEMM through the Userbuffers objects
      returned by ``get_ub``.

    ``backward`` returns gradients for ``weight``, ``inp`` and ``bias`` and
    ``None`` for every other ``forward`` argument.
    """

    @staticmethod
    def _allocate_output(
        inputmat_total: torch.Tensor,
        out_features: int,
        dtype: torch.dtype,
        row_divisor: int = 1,
    ) -> torch.Tensor:
        """Allocate an uninitialized GEMM output buffer.

        The buffer has the shape of ``inputmat_total``, except that dim 0 is
        divided (floor) by ``row_divisor`` and dim 1 is replaced by
        ``out_features``. It lives on ``inputmat_total``'s device.
        ``row_divisor`` is the TP world size when the buffer receives a
        reduce-scattered result.
        """
        dim_size = list(inputmat_total.size())
        dim_size[0] = dim_size[0] // row_divisor
        dim_size[1] = out_features
        return torch.empty(dim_size, dtype=dtype, device=inputmat_total.device)

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Union[Float8Tensor, None],
        weight_t_fp8: Union[Float8Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        skip_fp8_weight_update: Union[torch.Tensor, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        primary_weights_in_fp8: bool,
        ub_overlap_rs: bool,
        ub_overlap_ag: bool,
        ub_name: str,
        dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """Compute the linear projection of ``inp`` by ``weight``.

        Args:
            weight: weight tensor, or a ``Float8Tensor`` when
                ``primary_weights_in_fp8`` is set. ``weight.shape[-1]`` must
                equal ``inp.shape[-1]``.
            weight_fp8, weight_t_fp8: FP8 buffers whose ``_data`` receive the
                cast and the transpose of ``weight`` when the FP8 weight cast
                is refreshed. Ignored (overwritten) when
                ``primary_weights_in_fp8`` is set.
            inp: input activations. Flattened to ``[-1, in_features]``.
            bias, use_bias: optional bias, applied only when ``use_bias``.
            is_first_microbatch: ``None`` outside gradient accumulation.
                Otherwise, whether this is the first microbatch. Together with
                ``skip_fp8_weight_update`` it controls FP8 weight re-casting.
            skip_fp8_weight_update: optional tensor forwarded as ``noop_flag``
                to the fused weight cast-transpose (and to ``transpose_2d``
                in backward).
            fp8, fp8_calibration, fp8_meta: FP8 execution flag, amax
                calibration flag and FP8 scaling metadata.
            fuse_wgrad_accumulation: accumulate wgrad into ``weight.main_grad``.
            cpu_offloading: tag saved tensors for activation/weight offloading.
            tp_group, tp_size, sequence_parallel, tensor_parallel,
            parallel_mode: tensor-parallel configuration ("column"/"row").
            activation_dtype: compute dtype for non-FP8 GEMMs and outputs.
            is_grad_enabled: whether tensors are saved for backward.
            primary_weights_in_fp8: ``weight`` is already a ``Float8Tensor``.
            ub_overlap_rs, ub_overlap_ag, ub_name: Userbuffers overlap of the
                output reduce-scatter (fprop) / grad all-gather (dgrad).
            dummy_tensor: unused by the computation.

        Returns:
            Output reshaped to ``[-1, *inp.shape[1:-1], out_features]``. The
            first dimension changes under sequence parallelism.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

        # Re-cast the FP8 weight when not accumulating gradients, on the first
        # microbatch, or when a device-side skip flag decides it in-kernel.
        update_fp8_weights = (
            is_first_microbatch is None
            or is_first_microbatch
            or skip_fp8_weight_update is not None
        )

        tp_world_size = get_distributed_world_size(tp_group)
        # Reduce-scatter overlap is only meaningful with more than one TP rank.
        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs

        # Cast input to expected dtype
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_t = None
        inputmat_no_fp8 = inputmat
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if (
                not fp8_meta["recipe"].override_linear_precision.wgrad
                and is_grad_enabled
                and weight.requires_grad
                and not sequence_parallel
            ):
                # FP8 input for forward, FP8 input transpose for backward wgrad
                inputmat, inputmat_t = fp8_cast_transpose_fused(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                # FP8 input for forward
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        # Userbuffers state; stays None unless reduce-scatter overlap is active.
        ub_obj_projout = None
        ub_algo = None
        rs_out = None

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if primary_weights_in_fp8:
                # Weight is already in FP8
                weight.reset_fp8_meta_scale_inv()
                weight_fp8 = weight
            elif update_fp8_weights:
                # Need to cast weights to FP8
                weight_fp8 = Float8Tensor(
                    data=weight_fp8._data,
                    fp8_meta=fp8_meta,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                )
                if (is_grad_enabled
                        or (is_fp8_activation_recompute_enabled()
                            and not in_fp8_activation_recompute_phase())):
                    # The transpose is needed later (dgrad), so produce both.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8._data,
                        transpose_out=weight_t_fp8._data,
                        noop_flag=skip_fp8_weight_update,
                    )
                else:
                    cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        out=weight_fp8._data,
                    )
                    weight_t_fp8 = None

            proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                None, None, None, activation_dtype)
            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                rs_out = _Linear._allocate_output(
                    inputmat_total, weight.size(0), activation_dtype, tp_world_size
                )
                if ub_obj_projout.is_p2p_overlap():
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
                if ub_obj_projout.is_fp8_ubuf():
                    # GEMM writes an FP8 output directly into the UB buffer.
                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    proj_out_tetype = fp8_dtype_forward
                    proj_out_pttype = torch.uint8
                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
            else:
                out = _Linear._allocate_output(
                    inputmat_total, weight.size(0), activation_dtype
                )

            fp8_gemm(
                weight_fp8._data,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                proj_out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
                out_index=proj_out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=proj_out_tetype,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = (
                    torch.max(-amin, amax).float()
                )
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = (
                    torch.max(-amin, amax).float()
                )

            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                # Reuse the world size computed above instead of re-querying it.
                rs_out = _Linear._allocate_output(
                    inputmat_total, weight.size(0), activation_dtype, tp_world_size
                )
                # Atomic-GEMM variants are only selected on the FP8 path.
                if ub_obj_projout.is_p2p_overlap():
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
            else:
                out = _Linear._allocate_output(
                    inputmat_total, weight.size(0), activation_dtype
                )

            gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
            )

        if is_grad_enabled:
            saved_inputmat = None
            saved_inputmat_t = None
            if weight.requires_grad:
                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
                    # Keep the pre-transposed FP8 input when forward produced it;
                    # otherwise backward transposes the FP8 input itself.
                    if inputmat_t is None:
                        saved_inputmat = inputmat
                    else:
                        saved_inputmat_t = inputmat_t
                        if cpu_offloading:
                            saved_inputmat_t.activation_offloading = True
                else:
                    saved_inputmat = inputmat_no_fp8

                if cpu_offloading:
                    if fuse_wgrad_accumulation:
                        weight.main_grad.weight_offloading = True
                    if fp8 and weight_t_fp8 is not None:
                        weight_t_fp8.weight_offloading = True
                    weight.weight_offloading = True

                    if saved_inputmat is not None:
                        saved_inputmat.activation_offloading = True

            ctx.save_for_backward(
                saved_inputmat,
                saved_inputmat_t,
                weight,
                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
                weight_t_fp8 if fp8 else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
                skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.primary_weights_in_fp8 = primary_weights_in_fp8

        # Row Parallel Linear
        if ub_overlap_rs:
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients of ``forward``.

        Returns one entry per ``forward`` argument:
        ``(wgrad, None, None, dgrad, grad_bias, None * 20)``. ``dgrad`` is
        reshaped to the original input shape, and is ``None`` when the input
        did not require grad.
        """
        with torch.cuda.nvtx.range("_Linear_backward"):
            (
                inputmat,
                inputmat_t,
                weight,
                main_grad,
                weight_t_fp8,
                fwd_scale_inverses,
                skip_fp8_weight_update,
            ) = ctx.saved_tensors

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                # Re-attach main_grad, which was saved separately for offloading.
                weight = torch.nn.Parameter(weight, False)
                weight.main_grad = main_grad

            # Primary weights are in FP8.
            if ctx.primary_weights_in_fp8:
                weight_t_fp8 = weight.transpose_2d(
                    cache=ctx.is_first_microbatch is not None,
                    noop_flag=skip_fp8_weight_update,
                )
            elif ctx.fp8:
                weight_t_fp8 = weight_t_fp8._data

            tp_world_size = get_distributed_world_size(ctx.tp_group)
            ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
            ub_algo = None
            if ctx.ub_overlap_ag:
                # Stored on ctx: grad_output_preprocess may read it as well.
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                if ctx.ub_obj_gradout.is_atomic_gemm():
                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            inputmat_t_total = None
            handle = None
            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                inputmat_total, handle = gather_along_first_dim(
                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                )
            else:
                inputmat_total = inputmat
                inputmat_t_total = inputmat_t

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

            if ctx.requires_dgrad:
                if ctx.fp8:
                    dgrad, _ = fp8_gemm(
                        weight_t_fp8,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        ctx.activation_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo,
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                    )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        # Non-FP8 path always uses the split-pipelined algorithm.
                        ub_algo=(
                            tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                            if ctx.ub_overlap_ag else None
                        ),
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    # The input all-gather must finish before reusing the group.
                    if handle is not None:
                        handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_overlap_ag:
                            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        if inputmat_t_total is None:
                            # The input was cast to FP8 with the forward dtype,
                            # so its transpose is tagged with that dtype too.
                            inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_forward)
                        wgrad, _ = fp8_gemm(
                            inputmat_t_total,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (bias gradient is fused into the same GEMM)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                # Deallocate input tensor
                clear_tensor_data(inputmat_total)
                clear_tensor_data(inputmat_t_total)

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(weight, 'grad_added_to_main_grad'):
                weight.grad_added_to_main_grad = True
                # The real gradient already lives in main_grad; return a
                # placeholder of the right shape.
                if getattr(weight, 'zero_out_wgrad', False):
                    wgrad = torch.zeros(weight.main_grad.shape,
                                        dtype=weight.dtype,
                                        device=torch.cuda.current_device(),
                                        requires_grad=False
                                       )
                else:
                    wgrad = torch.empty(weight.main_grad.shape,
                                        dtype=weight.dtype,
                                        device=torch.cuda.current_device(),
                                        requires_grad=False
                                       )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        return (
            wgrad,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            grad_bias,  # bias
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # skip_fp8_weight_update
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # primary_weights_in_fp8
            None,  # ub_overlap_rs
            None,  # ub_overlap_ag
            None,  # ub_name
            None,  # dummy_tensor
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                           stored on the module and forwarded, together with `init_method`,
                           when the weight parameters are registered.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names. Split sizes refer to the full
                      (unsharded) `out_features` and must sum to it; with
                      `parallel_mode='column'` every split size must also be divisible by the
                      tensor-parallel size. Splitting is not supported on the `'meta'` device
                      or when the weights are stored in FP8.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the
          user's responsibility to ensure all parameters are moved to the GPU before
          running the forward pass.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when
                       the tensor-parallel size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   With `'column'` the local `out_features` is divided by the tensor-parallel
                   size; with `'row'` the local `in_features` is. When set to `None`, no
                   communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    ub_overlap_rs : bool, default = `False`
                   userbuffer communication-overlap flag forwarded to the autograd function
                   (the name suggests reduce-scatter overlap). When enabled, `ub_name` must be
                   set and the userbuffer communication backend must be available.
    ub_overlap_ag : bool, default = `False`
                   userbuffer communication-overlap flag forwarded to the autograd function
                   (the name suggests all-gather overlap). Same requirements as
                   `ub_overlap_rs`.
    ub_name : str, default = `None`
             userbuffer identifier forwarded to the autograd function; required when
             `ub_overlap_rs` or `ub_overlap_ag` is set.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With `return_bias` the caller adds the bias, so the module must not.
        self.apply_bias = bias and not return_bias
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()

        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
        if ub_overlap_rs or ub_overlap_ag:
            assert ub_name is not None, "Userbuffer name [string] is not set."
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."
        self.ub_name = ub_name

        self.get_rng_state_tracker = get_rng_state_tracker
        if device == 'meta':
            assert parameters_split is None, ("Cannot split module parameters "
                                              "on 'meta' device.")

        # Resolve tensor-parallel world size: an explicit group wins over `tp_size`.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Shard the local GEMM dimension according to the parallel mode.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Backing buffers; the registered parameters are (views of) these.
        self.weight_tensor = torch.empty(
            self.out_features, self.in_features,
            device=device, dtype=params_dtype)

        if self.use_bias:
            self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
        else:
            self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid (sizes refer to the global out_features)
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct parameters from weight and bias buffers
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and self.primary_weights_in_fp8:
                raise RuntimeError(
                    "Splitting Float8Tensor into multiple params "
                    "is not supported"
                )

            # Construct weight parameter
            weight = self.weight_tensor
            if is_subview:
                weight = weight[split_start:split_end]
            weight = torch.nn.Parameter(weight)
            self.register_parameter(self.weight_names[i], weight,
                                    init_fn=init_method,
                                    get_rng_state_tracker=get_rng_state_tracker,
                                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)

            # Construct bias parameter if needed
            if self.use_bias:
                bias = self.bias_tensor
                if is_subview:
                    bias = bias[split_start:split_end]
                bias = torch.nn.Parameter(bias)
                self.register_parameter(self.bias_names[i], bias,
                                        init_fn=init_method_constant(0.0))
            else:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, self.bias_names[i], bias)

        # Concatenated tensors are not needed if not splitting into multiple
        # parameters. Key this on the number of splits -- the same condition
        # `forward` uses to decide whether it reads `self.weight_tensor` /
        # `self.bias_tensor` -- so a multi-split config never loses its buffer.
        if len(self.parameter_split_sizes) == 1:
            del self.weight_tensor
            del self.bias_tensor

        if self.primary_weights_in_fp8:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=(device == 'meta'))

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
        FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)

    def reset_parameters(self, defer_init=False):
        """Re-initialize parameters and tag them with tensor-parallel attributes.

        Delegates initialization to the base class. Unless ``defer_init`` is set,
        every weight is marked as tensor-parallel along dim 1 (row mode) or dim 0
        (otherwise); biases get a ``sequence_parallel`` attribute in row mode or
        tensor-parallel attributes along dim 0 in column mode.
        """
        super().reset_parameters(defer_init=defer_init)

        if not defer_init:
            # Set parallelism attributes for linear weights
            for weight in self.weight_names:
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, weight),
                    is_parallel=True,
                    dim=1 if self.parallel_mode == "row" else 0,
                    stride=1,
                )

            # Set parallelism attributes for linear biases
            if self.use_bias:
                for bias in self.bias_names:
                    if self.parallel_mode == "row":
                        setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
                    elif self.parallel_mode == "column":
                        set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[Optional[Float8Tensor]]:
        """
        Fetch the fp8 weight tensor placeholders if they exist (when
        `is_first_microbatch` is not `None`) or return empty fp8 weight
        tensors (if `is_first_microbatch is None`).

        Returns `[None, None]` when FP8 is disabled or when the primary
        weights are already stored in FP8 (no cast scratchpad is needed).
        """
        if not self.fp8 or self.primary_weights_in_fp8:
            return [None, None]

        if is_first_microbatch is None:
            # Return empty weight placeholders for each fwd/bwd pass
            fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
                is_first_microbatch
            )
        else:
            # These persistent weight placeholders should've been created in
            # `set_fp8_weights` method
            fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]

        return fp8_weight_tensors

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

                             If `FP8GlobalStateManager` provides a skip-FP8-weight-update
                             tensor, this argument is overridden to `False`.

        Returns
        -------
        out : torch.Tensor
             Output of the linear transformation.
        bias : torch.Tensor
              Only when the module was built with `return_bias=True`: returned as the
              second element of a tuple, cast to the activation dtype.
        """

        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            assert self.fp8 or not self.primary_weights_in_fp8, \
                   "Need to run inside fp8_autocast region when weights are stored in FP8."

            # Get concatenated weight and bias tensors
            if len(self.parameter_split_sizes) == 1:
                weight_tensor = getattr(self, self.weight_names[0])
                bias_tensor = getattr(self, self.bias_names[0])
            elif torch.is_grad_enabled():
                # Reassemble the split parameters against the buffer they were
                # carved from, so autograd sees each individual parameter.
                weight_tensor = _noop_cat(
                    [getattr(self, name) for name in self.weight_names],
                    self.weight_tensor,
                )
                if self.use_bias:
                    bias_tensor = _noop_cat(
                        [getattr(self, name) for name in self.bias_names],
                        self.bias_tensor,
                    )
                else:
                    bias_tensor = getattr(self, self.bias_names[0])  # Unused
            else:
                # No autograd: the split params are views of these buffers.
                weight_tensor = self.weight_tensor
                bias_tensor = self.bias_tensor

            # Fetch the fp8 weights placeholders (for linear/gemm)
            weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
                is_first_microbatch
            )

            from ..cpu_offload import CPUOffloadEnabled

            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                # Call the autograd function's forward directly; `None` stands
                # in for the autograd context argument.
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                weight1_fp8,
                weight1_t_fp8,
                inp,
                bias_tensor,
                # Row-parallel bias is added after the collectives, not in the GEMM.
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                skip_fp8_weight_update,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                CPUOffloadEnabled,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.primary_weights_in_fp8,
                self.ub_overlap_rs,
                self.ub_overlap_ag,
                self.ub_name,
                self.dummy_tensor,
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out