linear.py 7.82 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.

"""Fusible operation for linear layer."""

from __future__ import annotations
from collections.abc import Callable
9
from typing import Any, Optional
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

import torch

from transformer_engine.pytorch.ops.basic import (
    AllReduce,
    BasicLinear,
    Bias,
    ReduceScatter,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.ops.op import FusedOperation


class Linear(FusedOperation):
    """Apply linear transformation: :math:`y = x A^T + b`

    This is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features: int
        Inner dimension of input tensor
    out_features: int
        Inner dimension of output tensor
    bias: bool, default = `True`
        Apply additive bias
    device: torch.device, default = default CUDA device
        Tensor device
    dtype: torch.dtype, default = default dtype
        Tensor datatype
    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
        Mode for tensor parallelism
    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
        Process group for tensor parallelism
    sequence_parallel: bool, default = `False`
        Whether to apply sequence parallelism together with tensor
        parallelism, i.e. distributing input or output tensors along
        outer dimension (sequence or batch dim) when not distributing
        along inner dimension (embedding dim)
    rng_state_tracker_function: callable
        Function that returns CudaRNGStatesTracker, which is used for
        model-parallel weight initialization
    accumulate_into_main_grad: bool, default = `False`
        Whether to directly accumulate weight gradients into the
        weight's `main_grad` attribute instead of relying on PyTorch
        autograd. The weight's `main_grad` must be set externally and
        there is no guarantee that `grad` will be set or be
57
58
        meaningful. This is primarily intented to integrate with
        Megatron-LM.
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        bias: bool = True,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        sequence_parallel: bool = False,
        rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
        accumulate_into_main_grad: bool = False,
    ) -> None:

        # Tensor parallel configuration
        (
            tensor_parallel_mode,
            tensor_parallel_group,
            tensor_parallel_size,
            sequence_parallel,
            local_in_features,
            local_out_features,
        ) = BasicLinear._canonicalize_tensor_parallelism(
            mode=tensor_parallel_mode,
            process_group=tensor_parallel_group,
            sequence_parallel=sequence_parallel,
            in_features=in_features,
            out_features=out_features,
        )

        # Construct basic ops
        ops = []
95
96
        linear_idx = None
        bias_idx = None
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        linear_kwargs = {
            "in_features": in_features,
            "out_features": out_features,
            "device": device,
            "dtype": dtype,
            "tensor_parallel_mode": tensor_parallel_mode,
            "tensor_parallel_group": tensor_parallel_group,
            "sequence_parallel": sequence_parallel,
            "rng_state_tracker_function": rng_state_tracker_function,
            "accumulate_into_main_grad": accumulate_into_main_grad,
        }
        bias_kwargs = {
            "size": out_features,
            "device": device,
            "dtype": dtype,
            "tensor_parallel": (tensor_parallel_mode is not None),
            "tensor_parallel_group": tensor_parallel_group,
        }
115
116
        if tensor_parallel_mode == "row":
            # Row TP: GEMM + bias + reduction
117
            linear_idx = len(ops)
118
119
120
121
122
123
124
            linear_kwargs["in_features"] = local_in_features
            linear_kwargs["out_features"] = local_out_features
            linear_kwargs["tensor_parallel_mode"] = None
            linear_kwargs["tensor_parallel_group"] = None
            linear_kwargs["sequence_parallel"] = False
            ops.append(BasicLinear(**linear_kwargs))
            if bias:
125
126
                bias_idx = len(ops)
                bias_kwargs["size"] *= tensor_parallel_size
127
128
129
130
131
132
133
                ops.append(Bias(**bias_kwargs))
            if sequence_parallel:
                ops.append(ReduceScatter(tensor_parallel_group))
            else:
                ops.append(AllReduce(tensor_parallel_group))
        else:
            # Column TP or no TP: (gather + GEMM) + bias
134
            linear_idx = len(ops)
135
136
            ops.append(BasicLinear(**linear_kwargs))
            if bias:
137
                bias_idx = len(ops)
138
139
140
141
142
                ops.append(Bias(**bias_kwargs))

        # Initialize base class
        super().__init__(ops)

143
144
145
146
147
148
149
150
        # Register parameters
        self._linear_idx: Optional[int] = linear_idx
        self._bias_idx: Optional[int] = bias_idx
        self.register_parameter("weight", self.basic_ops[self._linear_idx].weight)
        bias = None
        if self._bias_idx is not None:
            bias = self.basic_ops[self._bias_idx].bias
        self.register_parameter("bias", bias)
151

152
153
    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
        """Add a parameter to the module
154

155
        Also updates the basic operation that owns the parameter.
156
157

        """
158
        if name == "bias" and self._bias_idx is None and param is not None:
159
160
161
162
            raise ValueError(
                "Attempted to set bias parameter in Linear operation "
                "that does not have bias enabled"
            )
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
        super().register_parameter(name, param)
        if name == "weight":
            self.basic_ops[self._linear_idx].weight = param
        elif name == "bias" and self._bias_idx is not None:
            self.basic_ops[self._bias_idx].bias = param

    def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]:
        """Save state"""
        state_dict = super().state_dict(prefix=prefix, **kwargs)

        # Remove basic op params from state dict
        # Note: Logically, basic ops own params and fused ops are
        # considered as stateless. However, we register weight and
        # bias params in the linear op for convenience. We remove
        # these redudant params from the checkpoint for backward
        # compatibility.
        if f"{prefix}weight" in state_dict:
            del state_dict[f"{prefix}weight"]
        if f"{prefix}bias" in state_dict:
            del state_dict[f"{prefix}bias"]

        return state_dict

    def _load_from_state_dict(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        *args,
        **kwargs,
    ) -> None:

        # Add basic op params to state dict
        # Note: Logically, basic ops own params and fused ops are
        # considered as stateless. However, we register weight and
        # bias params in the linear op for convenience. We remove
        # these redudant params from the checkpoint for backward
        # compatibility.
        if f"{prefix}weight" not in state_dict:
            state_dict[f"{prefix}weight"] = state_dict[
                f"{prefix}basic_ops.{self._linear_idx}.weight"
            ]
        if f"{prefix}bias" not in state_dict:
            if self._bias_idx is None:
                state_dict[f"{prefix}bias"] = None
            else:
                state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"]

        # Load state dict
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)