layernorm.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
12
13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
    vllm_kernel_override_batch_invariant,
)
15
from vllm.platforms import current_platform
16
from vllm.utils import direct_register_custom_op
17
18
19


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the AITER RMSNorm path has been switched on.

    Returns:
        ``True`` only when both ``VLLM_ROCM_USE_AITER_RMSNORM`` and
        ``VLLM_ROCM_USE_AITER`` (read from ``vllm.envs``) are truthy.
    """
    # Coerce so the declared ``bool`` return type holds even if the env
    # accessors hand back a non-bool truthy/falsy value.
    return bool(envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER)
21
22


23
24
25
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` through the vLLM custom op.

    Args:
        x: Input tensor to normalize.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        The result of ``rms_norm_batch_invariant`` when the batch-invariant
        override is active; otherwise a freshly allocated tensor shaped like
        ``x`` that is handed to ``ops.rms_norm`` as its output buffer.
    """
    # Imported lazily, as in the rest of this module, so importing the module
    # itself does not pull in the compiled ops.
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    out = torch.empty_like(x)
    ops.rms_norm(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input tensor.
        residual: Residual tensor added to ``x`` before normalization.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A ``(normalized, new_residual)`` pair. With the batch-invariant
        override active these are new tensors; otherwise the very same ``x``
        and ``residual`` objects are returned after being passed to
        ``ops.fused_add_rms_norm``.
    """
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        # Compute the sum once and reuse it for both outputs instead of
        # evaluating ``x + residual`` twice.
        # NOTE(review): assumes rms_norm_batch_invariant does not modify its
        # input in place — confirm against its definition.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE(review): the op's result is ignored and the inputs are returned, so
    # it presumably updates ``x`` and ``residual`` in place — confirm.
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


61
62
63
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the vLLM ``poly_norm`` custom op on ``x``.

    Args:
        x: Input tensor.
        weight: Weight tensor forwarded to the kernel.
        bias: Bias tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A newly allocated tensor shaped like ``x`` that was handed to
        ``ops.poly_norm`` as its output buffer.
    """
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    ops.poly_norm(out, x, weight, bias, variance_epsilon)
    return out


77
78
79
def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMS-normalize ``x`` with ``aiter.rms_norm``.

    Inputs with more than two dimensions are flattened to
    ``(-1, last_dim)`` for the call and reshaped back afterwards.

    Args:
        x: Input tensor.
        weight: Scale tensor forwarded to ``aiter.rms_norm``.
        variance_epsilon: Epsilon forwarded to ``aiter.rms_norm``.

    Returns:
        The normalized tensor, with the same shape as the incoming ``x``.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # Collapse all leading dimensions for the call, then restore them.
    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normalized = rocm_aiter.rms_norm(flattened, weight, variance_epsilon)
    return normalized.reshape(original_shape)


91
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via ``aiter.rmsnorm2d_fwd_with_add``.

    Unlike ``fused_add_rms_norm``, this wrapper allocates separate output
    buffers rather than returning its input tensors.

    Args:
        x: Input tensor.
        residual: Residual input tensor.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        ``(output, residual_out)``: new tensors shaped like ``x`` and
        ``residual`` respectively, passed to the kernel as output buffers.
    """
    import aiter as rocm_aiter

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out
110
111


112
113
114
def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Shape-only stand-in registered as the ``fake_impl`` of the RMSNorm op.

    No normalization is performed; ``weight`` and ``variance_epsilon`` are
    accepted only to mirror the real implementation's signature.

    Returns:
        An uninitialized tensor with the same shape and dtype as ``x``.
    """
    return torch.empty_like(x)


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape-only stand-in registered as the ``fake_impl`` of the fused op.

    No computation is performed; ``weight`` and ``variance_epsilon`` are
    accepted only to mirror the real implementation's signature.

    Returns:
        Two uninitialized tensors: one matching ``x`` and one matching
        ``residual`` in shape and dtype.
    """
    output = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    return output, residual_out


# Register the AITER-backed RMSNorm wrappers as custom ops, but only when the
# current platform reports ROCm. Each registration pairs the real
# implementation (`op_func`) with its `*_fake` counterpart (`fake_impl`), which
# only allocates outputs of the matching shape.
# NOTE(review): dispatch_rocm_rmsnorm_func looks these up as
# `torch.ops.vllm.<op_name>`, so direct_register_custom_op presumably registers
# them under the `vllm` namespace — confirm against its definition.
if current_platform.is_rocm():
    # Plain RMSNorm.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        fake_impl=rocm_aiter_rms_norm_fake,
    )

    # RMSNorm fused with the residual add.
    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
    )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Pick the RMSNorm callable to use on ROCm.

    The AITER custom ops are chosen only when AITER RMSNorm is enabled and
    ``dtype`` is ``torch.float16`` or ``torch.bfloat16``; every other case
    falls back to the module-level ``rms_norm`` / ``fused_add_rms_norm``.

    Args:
        with_fused_add: Select the variant that also adds a residual.
        dtype: Dtype used to decide AITER eligibility.

    Returns:
        A callable taking ``(x, weight, eps)`` when ``with_fused_add`` is
        false, or ``(x, residual, weight, eps)`` when it is true.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in (
        torch.float16,
        torch.bfloat16,
    )

    if use_aiter:
        # Resolve only the op that is actually requested.
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
156
157


158
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Set up the layer.

        Args:
            hidden_size: Expected size of the last input dimension.
            eps: Epsilon added to the variance before the reciprocal sqrt.
            var_hidden_size: If given and different from ``hidden_size``,
                the variance is computed over only this many leading
                features of the last dimension.
            has_weight: When true the all-ones weight is wrapped in an
                ``nn.Parameter``; otherwise it stays a plain tensor attribute
                and is not applied in ``forward_native``.
            dtype: Dtype for the weight; ``None`` uses torch's default.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # ``None`` means "use the full hidden dimension for the variance".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        self.has_weight = has_weight

        # ``dtype=None`` makes torch.ones fall back to the default dtype, so a
        # single call covers both cases.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        # The ROCm callables are resolved once, at construction time, and only
        # exist on ROCm platforms; ``forward_hip`` relies on them.
        if current_platform.is_rocm():
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``self.hidden_size``.
            residual: Optional tensor added to ``x`` (in float32) before
                normalization.

        Returns:
            The normalized tensor in ``x``'s original dtype, or a
            ``(normalized, new_residual)`` pair when ``residual`` is given,
            where ``new_residual`` is the pre-normalization sum cast back to
            the original dtype.

        Raises:
            ValueError: If the last dimension differs from
                ``self.hidden_size`` or is smaller than the variance-size
                override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError(
                "Expected hidden_size to be "
                f"{self.hidden_size}, but found: {hidden_size}"
            )

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last dimension only; an Ellipsis keeps this valid for
            # inputs of any rank (e.g. 2-D ``(tokens, hidden)``) rather than
            # exactly 3-D.
            x_var = x[..., : self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before applying the weight (x.to(orig_dtype) * w).
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the module-level ``rms_norm`` / ``fused_add_rms_norm``.

        Falls back to ``forward_native`` when a variance-size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the callables chosen in ``__init__`` for ROCm.

        Falls back to ``forward_native`` when a variance-size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the IPEX ops.

        Falls back to ``forward_native`` when a variance-size override is set.
        With a residual, the same ``x`` and ``residual`` objects are returned
        after being passed to ``ops.fused_add_rms_norm``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size=..., eps=...`` summary string."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
299
300


301
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from ``RMSNorm``:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create a zero-initialized weight of length ``hidden_size``.

        With the ``1 + w`` scaling used in ``forward_static``, a zero weight
        corresponds to a scale of one.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Returns:
            The normalized tensor in ``x``'s original dtype, or a
            ``(normalized, new_residual)`` pair when ``residual`` is given.
            For float16 inputs the residual sum is formed in float32 and
            ``new_residual`` is returned in float32.
        """
        orig_dtype = x.dtype
        if residual is not None:
            # float16 inputs are summed in float32; other dtypes are summed
            # as-is.
            x = (
                x.float() + residual.float()
                if orig_dtype == torch.float16
                else x + residual
            )
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward_native``, compiling ``forward_static`` on first use.

        When already inside a ``torch.compile`` trace the native path is
        called directly. Otherwise, on the first call, ``forward_static`` is
        replaced on this instance by its ``torch.compile``-wrapped version.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            # Shadows the class-level staticmethod with an instance attribute.
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        """Initialize three weights to 1/3 each and a single zero bias."""
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by sqrt(mean(x^2) + eps) over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        The computation runs in float32 and the result is cast back to the
        dtype of ``x``.

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        orig_dtype = x.dtype
        x_float = x.to(torch.float32)
        output = (
            self.weight[0] * self._norm(x_float**3)
            + self.weight[1] * self._norm(x_float**2)
            + self.weight[2] * self._norm(x_float)
            + self.bias
        )
        return output.to(orig_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the module-level ``poly_norm`` custom-op wrapper."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    The weight and bias are float32 parameters and the normalization itself
    is computed on a float32 copy of the input; the result is cast back to
    the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create an all-ones weight and all-zeros bias of length ``dim``."""
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last ``dim`` features.

        Returns:
            A tensor with the same type as ``x``.
        """
        normalized = F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        )
        return normalized.type_as(x)