# layernorm.py (12.2 KB) -- blame view (Newer / Older)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
from typing import Optional, Union
6

7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
10

11
import vllm.envs as envs
12
from vllm.model_executor.custom_op import CustomOp
13
from vllm.platforms import current_platform
14
from vllm.utils import direct_register_custom_op
15
16
17


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm kernels should be used.

    Both ``VLLM_ROCM_USE_AITER_RMSNORM`` and ``VLLM_ROCM_USE_AITER`` must be
    truthy. The general AITER flag is only read when the RMSNorm-specific
    flag is set.
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not rmsnorm_flag:
        return rmsnorm_flag
    return envs.VLLM_ROCM_USE_AITER
19
20


21
22
23
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMSNorm to ``x`` using the compiled ``_custom_ops`` kernel.

    The kernel writes its result into a freshly allocated buffer with the
    same shape/dtype/device as ``x``; that buffer is returned.

    Args:
        x: Input tensor.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Numerical-stability epsilon forwarded to the kernel.

    Returns:
        The normalized tensor.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the fused residual-add + RMSNorm kernel from ``_custom_ops``.

    No output buffers are allocated here: the kernel is handed ``x`` and
    ``residual`` directly and those same tensor objects are returned, so the
    kernel presumably updates them in place (TODO confirm against the kernel).

    Args:
        x: Hidden states.
        residual: Residual stream combined with ``x``.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Numerical-stability epsilon forwarded to the kernel.

    Returns:
        The ``(x, residual)`` pair after the kernel call.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    outputs = (x, residual)
    return outputs


53
54
55
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply PolyNorm to ``x`` using the compiled ``_custom_ops`` kernel.

    Args:
        x: Input tensor.
        weight: Polynomial-term weights forwarded to the kernel.
        bias: Additive bias forwarded to the kernel.
        variance_epsilon: Numerical-stability epsilon forwarded to the kernel.

    Returns:
        A freshly allocated tensor shaped like ``x`` holding the kernel output.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


69
70
71
def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMSNorm via the AITER ``rms_norm`` kernel.

    Inputs of rank 3 or more are flattened to 2-D (all leading dims merged,
    last dim kept) before the kernel call. The result is then reshaped back
    to the original shape.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(original_shape)


83
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via AITER's ``rmsnorm2d_fwd_with_add``.

    Fresh output buffers are allocated for both the normalized result and the
    updated residual. The kernel fills them and they are returned as
    ``(output, residual_out)``.
    """
    import aiter as rocm_aiter

    output = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,  # normalized output buffer
        x,  # input
        residual,  # residual input
        residual_out,  # residual output buffer
        weight,
        variance_epsilon,
    )
    return output, residual_out
102
103


104
105
106
def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Fake implementation registered alongside ``rocm_aiter_rms_norm``.

    Only the output's metadata matters, so this returns an uninitialized
    tensor created with ``torch.empty_like(x)``. ``weight`` and
    ``variance_epsilon`` are unused.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation registered alongside ``rocm_aiter_rmsnorm2d_fwd_with_add``.

    Returns uninitialized tensors created with ``torch.empty_like`` from
    ``x`` and ``residual``, matching the real op's ``(output, residual_out)``
    pair. ``weight`` and ``variance_epsilon`` are unused.
    """
    like_x = torch.empty_like(x)
    like_residual = torch.empty_like(residual)
    return like_x, like_residual


if current_platform.is_rocm():
    # Register the AITER wrappers as custom ops so that they can be reached as
    # torch.ops.vllm.<op_name>. Each op is paired with a fake implementation
    # that only allocates empty outputs of the right shape/dtype.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        fake_impl=rocm_aiter_rms_norm_fake,
        op_func=rocm_aiter_rms_norm_impl,
    )

    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
    )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Choose the RMSNorm callable to use on ROCm.

    The AITER custom ops are selected only when AITER RMSNorm is enabled
    (see ``is_rocm_aiter_rmsnorm_enabled``) and ``dtype`` is float16 or
    bfloat16. Otherwise the ``_custom_ops``-based ``rms_norm`` /
    ``fused_add_rms_norm`` wrappers are returned.

    Args:
        with_fused_add: Return the residual-add variant when True.
        dtype: Weight dtype used to decide AITER eligibility.
    """
    aiter_dtypes = (torch.float16, torch.bfloat16)
    if is_rocm_aiter_rmsnorm_enabled() and dtype in aiter_dtypes:
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
148
149


150
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the last dimension of inputs.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If given and different from ``hidden_size``,
                only the first ``var_hidden_size`` channels of the last
                dimension contribute to the variance, while the whole row
                is still scaled. Layers configured this way always take
                the ``forward_native`` path.
            has_weight: When False, ``weight`` stays a plain tensor of ones
                (not an ``nn.Parameter``) and ``forward_native`` skips the
                multiplication by it.
            dtype: Dtype of ``weight``; torch's default dtype when None.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        self.has_weight = has_weight
        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, keyed on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The (optional) residual add and the normalization are computed in
        float32. When ``residual`` is given, the summed tensor, cast back to
        the input dtype, is returned as the new residual.

        Raises:
            ValueError: If the last dimension of ``x`` differs from
                ``hidden_size``, or is smaller than ``var_hidden_size``.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError(
                "Expected hidden_size to be "
                f"{self.hidden_size}, but found: {hidden_size}"
            )

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last (hidden) dimension via Ellipsis so that inputs
            # of any rank work, e.g. 2-D (tokens, hidden) as well as 3-D,
            # instead of hard-coding exactly three dimensions.
            x_var = x[..., : self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: ``_custom_ops`` kernels, or native when variance is overridden."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """ROCm path: the callables chosen in ``__init__`` via dispatch_rocm_rmsnorm_func."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """XPU path: IPEX kernels, or native when variance is overridden."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
# (blame annotation) Woosuk Kwon committed
291
292


293
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from RMSNorm in two ways:
        1. the scale is (1 + w) rather than w;
        2. the result is cast back after scaling, i.e. (x * w).to(orig_dtype)
           rather than x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialized, so the effective scale (1 + weight) starts at 1.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        input_dtype = x.dtype
        if residual is not None:
            # fp16 sums are formed in float32; other dtypes add natively. The
            # sum (in whichever dtype it was computed) becomes the returned
            # residual, so fp16 callers get a float32 residual back.
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        normed = hidden.to(input_dtype)
        if residual is None:
            return normed
        return normed, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: run forward_static through a lazily created torch.compile.

        When already inside a torch.compile trace, nothing is set up and the
        native path is returned directly. Otherwise, on the first call,
        ``forward_static`` is replaced on this instance by its compiled form.
        """
        if not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        ):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Three polynomial-term weights, initialised to 1/3 each.
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Unweighted RMS normalization over the last dimension."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        xf = x.to(torch.float32)

        cubic_term = self.weight[0] * self._norm(xf**3)
        square_term = self.weight[1] * self._norm(xf**2)
        linear_term = self.weight[2] * self._norm(xf)
        # Same left-to-right accumulation order as a single chained sum.
        result = cubic_term + square_term + linear_term + self.bias
        return result.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path via the ``poly_norm`` custom-op wrapper."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420


class LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    The input is upcast to float32 and normalized over its last dimension
    (of size ``dim``) with the learned affine ``weight``/``bias``. The result
    is then cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        # Upcast the affine params together with the input. If the module was
        # cast to lower precision (e.g. ``.to(torch.bfloat16)``), F.layer_norm
        # would otherwise get float32 input with bf16/fp16 params, a
        # combination it can reject. ``.float()`` returns the tensor itself
        # when it is already float32, so the default case is unchanged.
        return F.layer_norm(
            x.float(),
            (self.dim,),
            self.weight.float(),
            self.bias.float(),
            self.eps,
        ).type_as(x)