layernorm.py 12.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
from vllm.platforms import current_platform
12
from vllm.utils import direct_register_custom_op
13
14
15


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Tell whether the ROCm AITER RMSNorm kernels were requested.

    Both the op-specific ``VLLM_ROCM_USE_AITER_RMSNORM`` flag and the global
    ``VLLM_ROCM_USE_AITER`` flag must be truthy.
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    return rmsnorm_flag and envs.VLLM_ROCM_USE_AITER
17
18


19
20
21
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMS-normalize ``x`` with the compiled vLLM kernel.

    ``x`` is not used as the output buffer: the kernel writes into a freshly
    allocated tensor with the same shape/dtype/device as ``x``, which is
    returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual-add followed by RMSNorm, via the fused vLLM kernel.

    Nothing is allocated here: ``x`` and ``residual`` are handed to
    ``ops.fused_add_rms_norm`` and the very same tensor objects are returned,
    so the kernel is expected to update them in place.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


51
52
53
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply PolyNorm with the compiled vLLM kernel.

    The kernel writes into a new tensor shaped and typed like ``x``; that
    tensor is returned and ``x`` itself is not used as the output buffer.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


67
68
69
def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMSNorm through the ROCm AITER kernel.

    Inputs with at most two dimensions go straight to ``aiter.rms_norm``.
    Higher-rank inputs are flattened to 2-D (all leading dims merged), passed
    through the kernel, and reshaped back to their original shape.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    full_shape = x.shape
    flat = x.reshape(-1, full_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(full_shape)


81
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual-add + RMSNorm through AITER's ``rmsnorm2d_fwd_with_add``.

    Fresh buffers are allocated for both results (shaped like ``x`` and
    ``residual`` respectively) and returned as ``(normalized, new_residual)``.
    """
    import aiter as rocm_aiter

    new_residual = torch.empty_like(residual)
    normed = torch.empty_like(x)
    # Argument order expected by AITER:
    # (out, input, residual_in, residual_out, weight, eps)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normed,
        x,
        residual,
        new_residual,
        weight,
        variance_epsilon,
    )
    return normed, new_residual
100
101


102
103
104
def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Fake (metadata-only) counterpart of ``rocm_aiter_rms_norm_impl``.

    Produces an uninitialized tensor with the shape/dtype/device of ``x``;
    ``weight`` and ``variance_epsilon`` are accepted only to match the real
    implementation's signature.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (metadata-only) counterpart of the fused AITER add + RMSNorm op.

    Returns two uninitialized tensors mirroring ``x`` and ``residual``
    (shape, dtype and device), in that order.
    """
    return tuple(torch.empty_like(template) for template in (x, residual))


if current_platform.is_rocm():
    # Register the AITER-backed RMSNorm kernels as custom ops, on ROCm only.
    # dispatch_rocm_rmsnorm_func looks them up as
    # ``torch.ops.vllm.rocm_aiter_rms_norm`` and
    # ``torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add``.
    # The ``fake_impl`` functions only allocate outputs of the right
    # shape/dtype; presumably they are used when tracing (e.g. under
    # torch.compile) — confirm against direct_register_custom_op.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        fake_impl=rocm_aiter_rms_norm_fake,
    )

    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
    )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Choose the RMSNorm callable to use on ROCm.

    The AITER custom ops are selected only when AITER RMSNorm is enabled and
    ``dtype`` is float16 or bfloat16; otherwise the generic vLLM kernels
    (``fused_add_rms_norm`` / ``rms_norm``) are returned.

    Args:
        with_fused_add: Select the residual-add + norm variant.
        dtype: dtype of the norm weight.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    aiter_ok = is_rocm_aiter_rmsnorm_enabled() and dtype in half_dtypes

    if not aiter_ok:
        # Fall back to the CUDA-style custom-op implementations.
        return fused_add_rms_norm if with_fused_add else rms_norm
    if with_fused_add:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
    return torch.ops.vllm.rocm_aiter_rms_norm
146
147


148
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the last input dimension.
            eps: Added to the mean square before taking the reciprocal sqrt.
            var_hidden_size: When set and different from ``hidden_size``,
                only the first ``var_hidden_size`` features of the last dim
                are used to compute the variance; the whole vector is still
                normalized. Forces the PyTorch-native path.
            has_weight: When False, the all-ones scale is kept as a plain
                tensor rather than a learnable ``nn.Parameter``.
            dtype: dtype of the scale tensor (torch default when None).
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Only keep an override when it actually differs from hidden_size so
        # the kernel-backed forward paths stay usable in the common case.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        self.has_weight = has_weight

        # dtype=None makes torch.ones use the default dtype.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, keyed on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor, or ``(normalized, x + residual)`` when
        ``residual`` is given. Both outputs are cast back to ``x``'s dtype.

        Raises:
            ValueError: if the last dim of ``x`` is not ``hidden_size``, or is
                smaller than the variance override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            # The residual sum is done in fp32 and handed back in orig dtype.
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError(
                "Expected hidden_size to be "
                f"{self.hidden_size}, but found: {hidden_size}"
            )

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dimension so inputs of any rank work
            # (e.g. 2-D (tokens, hidden) as well as 3-D batched tensors).
            x_var = x[..., : self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before scaling (Gemma's variant differs on this point).
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path backed by the vLLM custom kernels.

        With a residual, ``fused_add_rms_norm`` updates ``x`` and
        ``residual`` in place and returns them.
        """
        # The kernels take no variance-window argument.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the kernels chosen in ``__init__``."""
        # The dispatched kernels take no variance-window argument.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Intel XPU path backed by the IPEX ops."""
        # The IPEX ops take no variance-window argument.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # Returns x and residual themselves, so the op is expected to
            # update them in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
289
290


291
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from the plain RMSNorm in two ways:
        1. The scale is (1 + w) rather than w.
        2. Scaling happens in fp32 and the cast back to the input dtype comes
           last, i.e. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialized: with the (1 + w) scale this starts as identity.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor, or ``(normalized, x + residual)`` when
        a residual is supplied. For float16 inputs the residual sum is done
        (and returned) in float32.
        """
        input_dtype = x.dtype
        if residual is not None:
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama casts before scaling, Gemma scales in fp32 and casts after.
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        hidden = hidden.to(input_dtype)
        if residual is None:
            return hidden
        return hidden, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: run forward_static through torch.compile.

        The compiled function is created lazily, once per instance, and
        stored on the instance so it shadows the static method. Nothing is
        compiled while an outer torch.compile trace is already in progress.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    with learned weights w_n and learned bias b.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Three polynomial coefficients, initialized to 1/3 each, plus a bias.
        self.weight = nn.Parameter(torch.ones(3) / 3)
        self.bias = nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        # Unweighted RMS normalization over the last dimension.
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Computation runs in float32 and the result is cast back to the input
        dtype. Refer to
        https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        x32 = x.to(torch.float32)

        cubic = self.weight[0] * self._norm(x32**3)
        quadratic = self.weight[1] * self._norm(x32**2)
        linear = self.weight[2] * self._norm(x32)
        combined = cubic + quadratic + linear + self.bias
        return combined.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path backed by the vLLM ``poly_norm`` kernel."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418


class LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    The input is upcast to float32, normalized over its last dimension
    (``dim`` features) with a learned scale and shift, and cast back to the
    input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast the affine parameters too. If the module was cast (e.g.
        # ``.to(torch.bfloat16)``), F.layer_norm would otherwise receive an
        # fp32 input with lower-precision parameters, a mixed-dtype call that
        # PyTorch rejects. ``.float()`` is a no-op when they are already fp32.
        return F.layer_norm(
            x.float(),
            (self.dim,),
            self.weight.float(),
            self.bias.float(),
            self.eps,
        ).type_as(x)