layernorm.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4
from typing import Optional, Union
5

6
7
8
import torch
import torch.nn as nn

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
from vllm.platforms import current_platform
12
from vllm.utils import direct_register_custom_op
13
14
15


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Tell whether the AITER RMSNorm kernels should be used on ROCm.

    Both the RMSNorm-specific switch and the global AITER switch from
    ``vllm.envs`` have to be enabled.
    """
    rmsnorm_switch = envs.VLLM_ROCM_USE_AITER_RMSNORM
    return rmsnorm_switch and envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Run the ``_custom_ops.rms_norm`` kernel out-of-place.

    A fresh tensor shaped like ``x`` is allocated and passed to the kernel as
    its output buffer; ``x`` itself is only read.
    """
    # Imported lazily so the compiled extension is only loaded on first use.
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the ``_custom_ops.fused_add_rms_norm`` kernel.

    No output buffers are allocated: the kernel writes its results into
    ``x`` and ``residual``, and those very tensors are returned. This matches
    ``RMSNorm.forward_native``: ``residual`` becomes ``x + residual`` and
    ``x`` becomes the normalized sum.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


46
47
48
49
50
51
52
53
54
55
56
57
58
59
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
    """Run the ``_custom_ops.poly_norm`` kernel out-of-place.

    ``weight`` and ``bias`` are forwarded unchanged to the kernel, which
    writes into a newly allocated tensor shaped like ``x``.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


60
61
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Apply AITER's ``rms_norm`` to ``x``, flattening inputs above 2-D.

    Inputs with more than two dimensions are reshaped to
    ``(-1, x.shape[-1])`` before the kernel call and restored afterwards.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # NOTE(review): presumably the AITER kernel only accepts 2-D input,
    # hence the flatten/restore round trip — confirm against AITER docs.
    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(original_shape)


72
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Call AITER's ``rmsnorm2d_fwd_with_add`` with fresh output buffers.

    Unlike ``fused_add_rms_norm``, the inputs are not written to; two new
    tensors (shaped like ``x`` and ``residual``) receive the results and are
    returned as ``(output, residual_out)``.
    """
    import aiter as rocm_aiter

    new_residual = torch.empty_like(residual)
    normed = torch.empty_like(x)
    # Argument order: output, input, residual input, residual output,
    # weight, epsilon. Presumably new_residual receives x + residual and
    # normed its normalization — verify against the AITER kernel.
    rocm_aiter.rmsnorm2d_fwd_with_add(normed, x, residual, new_residual,
                                      weight, variance_epsilon)
    return normed, new_residual
89
90


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Fake implementation of ``rocm_aiter_rms_norm``.

    Nothing is computed; only an uninitialized tensor with ``x``'s
    shape/dtype/layout is returned.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of ``rocm_aiter_rmsnorm2d_fwd_with_add``.

    Returns uninitialized tensors shaped like ``x`` and ``residual``,
    mirroring the ``(output, residual_out)`` pair of the real op.
    """
    fake_output = torch.empty_like(x)
    fake_residual = torch.empty_like(residual)
    return fake_output, fake_residual


if current_platform.is_rocm():
    # Register each AITER kernel together with its *_fake counterpart (which
    # only allocates outputs); they are later looked up via
    # torch.ops.vllm.<op_name> in dispatch_rocm_rmsnorm_func.
    for _op_name, _op_func, _fake_impl in (
        ("rocm_aiter_rms_norm", rocm_aiter_rms_norm_impl,
         rocm_aiter_rms_norm_fake),
        ("rocm_aiter_rmsnorm2d_fwd_with_add",
         rocm_aiter_rmsnorm2d_fwd_with_add_impl,
         rocm_aiter_rmsnorm2d_fwd_with_add_fake),
    ):
        direct_register_custom_op(op_name=_op_name,
                                  op_func=_op_func,
                                  fake_impl=_fake_impl)
    # Keep the module namespace free of loop temporaries.
    del _op_name, _op_func, _fake_impl


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Pick the RMSNorm callable to use on ROCm.

    Args:
        with_fused_add: Select the variant that also takes a residual.
        dtype: Weight dtype; AITER is only used for float16/bfloat16.

    Returns:
        The registered AITER custom op when AITER RMSNorm is enabled and the
        dtype is supported, otherwise ``fused_add_rms_norm`` / ``rms_norm``.
    """
    use_aiter = (is_rocm_aiter_rmsnorm_enabled()
                 and dtype in (torch.float16, torch.bfloat16))

    if use_aiter:
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # Fall back to the generic _custom_ops kernels.
    return fused_add_rms_norm if with_fused_add else rms_norm
130
131


132
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Size of the last dimension of the inputs.
        eps: Constant added to the variance for numerical stability.
        var_hidden_size: If given and different from ``hidden_size``, only
            the first ``var_hidden_size`` features of the last dimension are
            used to compute the variance (normalization is still applied to
            all features). Only the PyTorch-native path supports this; every
            backend-specific forward delegates to ``forward_native`` then.
        has_weight: If True the scale is a learnable ``nn.Parameter``;
            otherwise ``self.weight`` is a plain all-ones tensor (it is still
            passed to the fused kernels, but not applied in forward_native).
        dtype: Optional dtype for the weight tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # A variance width equal to the full hidden size is the regular case;
        # normalizing it to None keeps the fused-kernel paths usable.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight
        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        if current_platform.is_rocm():
            # Resolve AITER vs. fallback kernels once, from the weight dtype,
            # instead of on every forward call.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype)
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``. Any
                number of leading dimensions is accepted.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor in ``x``'s dtype or, when ``residual`` is
            given, a tuple of it and the updated residual (``x + residual``
            computed in float32 and cast back to ``x``'s dtype).

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``
                or is smaller than ``var_hidden_size``.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the last (feature) axis regardless of rank: a fixed
            # x[:, :, :n] fails on 2-D inputs and slices the wrong axis on
            # 4-D+ inputs, while being identical to this for 3-D inputs.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Unlike GemmaRMSNorm, cast back before applying the weight.
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path using the ``_custom_ops`` RMSNorm kernels.

        Falls back to ``forward_native`` when a partial variance width is
        configured, which the fused kernels do not handle.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """ROCm path using the kernels chosen in ``__init__``.

        Falls back to ``forward_native`` when a partial variance width is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
                                                self.variance_epsilon)
        else:
            return self.rocm_norm_func(x, self.weight.data,
                                       self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Intel XPU path using the IPEX RMSNorm kernels.

        Falls back to ``forward_native`` when a partial variance width is
        configured. With a residual, ``x`` and ``residual`` are passed to
        ``fused_add_rms_norm`` and then returned as the result tensors.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Summarize the hidden size and epsilon for ``repr(module)``."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
266
267


268
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization as used by Gemma models.

    Compared with ``RMSNorm`` in this module:
        1. The scale is ``(1 + w)`` rather than ``w`` (so ``w`` starts at 0).
        2. Scaling happens in float32 before casting back, i.e.
           ``(x * w).to(orig_dtype)`` instead of ``x.to(orig_dtype) * w``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Kept free of ``self`` so ``forward_cuda`` can ``torch.compile`` it.
        When ``residual`` is given, returns ``(normed, x + residual)``; for
        float16 inputs that sum is formed and returned in float32
        (presumably to avoid fp16 overflow in the residual stream).
        """
        input_dtype = x.dtype
        if residual is not None:
            addend = (residual.float()
                      if input_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        hidden = hidden.to(input_dtype)
        if residual is None:
            return hidden
        return hidden, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run ``forward_static`` through ``torch.compile`` on first eager use.

        Inside an outer compilation the plain native path is used so the
        graph is not compiled twice. Outside one, the compiled function is
        cached on the instance (shadowing the staticmethod) the first time.
        """
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1

    The three weights start at 1/3 each and the scalar bias at 0.
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Divide ``x`` by its RMS over the last dimension (no weight)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md

        The computation runs in float32 and the result is cast back to
        ``x``'s dtype.
        """
        input_dtype = x.dtype
        x32 = x.to(torch.float32)
        cubic_term = self.weight[0] * self._norm(x32**3)
        square_term = self.weight[1] * self._norm(x32**2)
        linear_term = self.weight[2] * self._norm(x32)
        combined = cubic_term + square_term + linear_term + self.bias
        return combined.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the fused ``poly_norm`` custom kernel."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)