"vscode:/vscode.git/clone" did not exist on "30870b4f66414020645608b81dced94d8a99111c"
layernorm.py 10 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
zhuwenwen's avatar
zhuwenwen committed
4
from typing import Optional, Union, Tuple
5

6
7
8
import torch
import torch.nn as nn

zhuwenwen's avatar
zhuwenwen committed
9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether RMSNorm should be routed to the ROCm AITER kernels.

    Three conditions are checked in order, stopping at the first falsy one:
    the platform is ROCm, ``VLLM_ROCM_USE_AITER_RMSNORM`` is set, and
    ``VLLM_ROCM_USE_AITER`` is set. As with a chained ``and``, the value
    returned is the first falsy check result, or the last flag.
    """
    on_rocm = current_platform.is_rocm()
    if not on_rocm:
        return on_rocm
    aiter_rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not aiter_rmsnorm_flag:
        return aiter_rmsnorm_flag
    return envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm to ``x`` using vLLM's custom kernels.

    A fresh output tensor shaped like ``x`` is allocated and handed to the
    kernel as its first argument. When ``envs.VLLM_USE_OPT_OP`` is truthy,
    the ``rms_norm_opt`` variant is used instead of ``rms_norm``.

    Args:
        x: Input activations.
        weight: Per-feature scale.
        variance_epsilon: Stabilizer added to the mean square.

    Returns:
        The output tensor filled in by the kernel.
    """
    from vllm import _custom_ops as ops
    normed = torch.empty_like(x)
    # Both kernels share the (out, input, weight, eps) calling convention.
    kernel = ops.rms_norm_opt if envs.VLLM_USE_OPT_OP else ops.rms_norm
    kernel(normed, x, weight, variance_epsilon)
    return normed


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual add fused with RMSNorm, via vLLM's custom kernels.

    The kernel receives ``x`` and ``residual`` with no separate output
    buffers, and those same tensor objects are returned. The results are
    therefore presumably written in place (NOTE(review): confirm against
    the kernel). ``envs.VLLM_USE_OPT_OP`` selects the ``_opt`` variant.

    Returns:
        ``(x, residual)``, the same tensor objects that were passed in.
    """
    from vllm import _custom_ops as ops
    kernel = (ops.fused_add_rms_norm_opt
              if envs.VLLM_USE_OPT_OP else ops.fused_add_rms_norm)
    kernel(x, residual, weight, variance_epsilon)
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm with the ROCm AITER ``rms_norm`` kernel.

    Inputs of rank at most 2 go straight to the kernel. Higher-rank inputs
    are flattened to ``(-1, last_dim)``, normalized, and reshaped back to
    the caller's shape.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # Presumably the AITER kernel expects 2-D input: fold the leading dims
    # into rows, then restore the original shape afterwards.
    orig_shape = x.shape
    rows = x.reshape(-1, orig_shape[-1])
    normed_rows = rocm_aiter.rms_norm(rows, weight, variance_epsilon)
    return normed_rows.reshape(orig_shape)


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual add fused with RMSNorm, via ROCm AITER.

    Fresh buffers are allocated for the normalized output and the updated
    residual, and ``rmsnorm2d_fwd_with_add`` fills them. The caller's ``x``
    and ``residual`` are passed only as inputs.

    Returns:
        ``(normalized_output, new_residual)``, i.e. the two new buffers.
    """
    import aiter as rocm_aiter

    new_residual = torch.empty_like(residual)
    normed = torch.empty_like(x)
    # Positional order: output, input, residual input, residual output,
    # weight, epsilon.
    rocm_aiter.rmsnorm2d_fwd_with_add(normed, x, residual, new_residual,
                                      weight, variance_epsilon)
    return normed, new_residual


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Pick the RMSNorm implementation used by the CUDA/ROCm paths.

    Args:
        add_residual: Whether the caller also passes a residual tensor.
            If true, the fused add + norm variant is selected.

    Returns:
        The ROCm AITER implementation when
        ``is_rocm_aiter_rmsnorm_enabled()`` is truthy, otherwise the vLLM
        custom-op wrapper.
    """
    if add_residual:
        aiter_impl = rocm_aiter_fused_add_rms_norm
        default_impl = fused_add_rms_norm
    else:
        aiter_impl = rocm_aiter_rms_norm
        default_impl = rms_norm
    return aiter_impl if is_rocm_aiter_rmsnorm_enabled() else default_impl


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Expected size of the input's last dimension.
        eps: Value added to the mean square before the reciprocal sqrt.
        var_hidden_size: If set and different from ``hidden_size``, the
            mean square is taken over only the first ``var_hidden_size``
            features of the last dimension, but all features are still
            scaled. Only ``forward_native`` implements this. Every other
            backend falls back to it while the override is active.
        has_weight: If True, ``weight`` is an ``nn.Parameter``. If False,
            it is a plain all-ones tensor and ``forward_native`` skips the
            multiplication.
        dtype: Optional dtype for the weight tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # var_hidden_size == hidden_size is equivalent to no override.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        # NOTE(review): when has_weight is False, the all-ones tensor stays a
        # plain attribute (neither a parameter nor a buffer), so Module.to()
        # will not move it. The kernel paths still read self.weight.data.
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Normalization is computed in float32. The result is cast back to
        ``x``'s dtype before the weight is applied.

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
                Any rank is accepted.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor. If ``residual`` is given, returns
            ``(normalized, x + residual)``, with the sum cast to ``x``'s
            original dtype.

        Raises:
            ValueError: If the last dimension is not ``hidden_size``, or is
                smaller than the variance-size override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the LAST dim for inputs of any rank. The former
            # x[:, :, :k] failed on 2-D input and sliced the wrong axis
            # for 4-D input.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the kernel selected by ``dispatch_cuda_rmsnorm_func``.

        Falls back to ``forward_native`` when a variance-size override is
        set, because the kernels normalize over the whole last dimension.
        Returns whatever the selected wrapper returns: the normalized
        tensor, or a ``(normalized, residual)`` pair when ``residual`` is
        given.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        return norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Use NVIDIA Apex's fused RMSNorm when there is no residual.

        With a residual, this reuses the kernels chosen by
        ``dispatch_cuda_rmsnorm_func``, the same as ``forward_cuda``.
        """
        # Apex normalizes over the full last dimension
        # (normalized_shape == (x.shape[-1],)), so it cannot honor a
        # variance-size override. Fall back like the other backends.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            norm_func = dispatch_cuda_rmsnorm_func(add_residual=True)
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)

        # Imported lazily so that only this path depends on apex.
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine
        return fused_rms_norm_affine(x, self.weight.data,
                                     torch.Size((x.shape[-1], )),
                                     self.variance_epsilon)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Use the HPU fused RMSNorm from ``vllm_hpu_extension``.

        Falls back to ``forward_native`` when a variance-size override is
        set, or when the extension returns no kernel. With a residual,
        ``residual`` is updated in place (``residual += x``) and returned
        alongside the normalized output.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Aliased so the module-level rms_norm function is not shadowed.
        from vllm_hpu_extension.kernels import (
            rms_norm as get_hpu_fused_rms_norm)
        HPUFusedRMSNorm = get_hpu_fused_rms_norm()
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Use IPEX kernels from ``vllm._ipex_ops``.

        Falls back to ``forward_native`` when a variance-size override is
        set. With a residual, ``ops.fused_add_rms_norm`` is called on ``x``
        and ``residual``, and those same tensors are returned (presumably
        updated in place by the kernel).
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Describe the hidden size (taken from the weight) and eps."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.

    The weight is initialized to zeros, so the initial scale (1 + w) is 1.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            weight: Gemma offset weight; the effective scale is 1 + weight.
            variance_epsilon: Stabilizer added to the mean square.
            x: Input activations.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor in ``x``'s dtype. If ``residual`` is given,
            returns ``(normalized, x + residual)``, where the sum is not cast
            back (for float16 inputs it is float32).
        """
        in_dtype = x.dtype
        if residual is not None:
            # float16 sums go through residual.float(), which promotes
            # them to float32. Other dtypes add as-is.
            addend = residual.float() if in_dtype == torch.float16 else residual
            x = x + addend
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        normed = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + weight.float())
        out = scaled.to(in_dtype)
        if residual is None:
            return out
        return out, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        # Goes through self.forward_static, so the compiled replacement
        # installed by forward_cuda (if any) is picked up here.
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the native path through a lazily torch.compile'd kernel.

        The first eager call replaces this instance's ``forward_static`` with
        a ``torch.compile``'d version. While an outer compilation is tracing,
        nothing is replaced and the native path is traced directly.
        """
        if (not torch.compiler.is_compiling()
                and not getattr(self, "_is_compiled", False)):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)