layernorm.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4
from typing import Optional, Union
5

6
7
import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.platforms import current_platform
13
from vllm.utils import direct_register_custom_op
14
15
16


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the AITER RMSNorm kernels are switched on.

    Both ``envs.VLLM_ROCM_USE_AITER_RMSNORM`` and ``envs.VLLM_ROCM_USE_AITER``
    must be truthy. This helper does not check the platform itself; in this
    module it is only consulted by ``dispatch_rocm_rmsnorm_func``.

    Returns:
        The first falsy flag if either flag is falsy, otherwise the value of
        ``envs.VLLM_ROCM_USE_AITER`` (same result as ``a and b``).
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not rmsnorm_flag:
        # Short-circuit exactly like `and`: the global flag is not read.
        return rmsnorm_flag
    return envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Run vLLM's compiled RMSNorm kernel out of place.

    Args:
        x: Input activations.
        weight: Per-feature scale passed through to the kernel.
        variance_epsilon: Epsilon added to the mean square.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` and filled by
        ``ops.rms_norm``.
    """
    # Imported lazily so the compiled extension is only loaded when used.
    from vllm import _custom_ops as ops
    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual-add followed by RMSNorm using vLLM's compiled kernel.

    No output buffers are allocated: ``ops.fused_add_rms_norm`` receives
    ``x`` and ``residual`` directly, and those same tensor objects are
    returned. Presumably the kernel writes the normalized output into ``x``
    and the updated residual into ``residual`` in place (TODO confirm
    against the kernel).

    Returns:
        The ``(x, residual)`` pair that was passed in.
    """
    from vllm import _custom_ops as ops
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    outputs = (x, residual)
    return outputs


def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
    """Run vLLM's compiled PolyNorm kernel out of place.

    Args:
        x: Input activations.
        weight: Polynomial-term weights passed through to the kernel.
        bias: Bias passed through to the kernel.
        variance_epsilon: Epsilon used by the inner RMS normalizations.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` and filled by
        ``ops.poly_norm``.
    """
    from vllm import _custom_ops as ops
    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """RMSNorm via ROCm AITER (``aiter.rms_norm``).

    Inputs with more than two dimensions have their leading dimensions
    merged into one before the call and the original shape is restored
    afterwards; 1-D and 2-D inputs are passed through unchanged. The
    flattening presumably exists because the AITER kernel expects 2-D input
    (NOTE(review): confirm against the aiter API).
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flattened, weight, variance_epsilon)
    return normed.reshape(original_shape)


def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via ROCm AITER.

    Unlike the CUDA path, the inputs are not modified through the returned
    objects: fresh buffers (like ``residual`` and like ``x``) are allocated
    and handed to ``aiter.rmsnorm2d_fwd_with_add`` as its outputs.

    Returns:
        ``(normalized_output, updated_residual)``, both newly allocated.
    """
    import aiter as rocm_aiter

    new_residual = torch.empty_like(residual)
    normed = torch.empty_like(x)
    # Argument order expected by aiter:
    # output, input, residual_in, residual_out, weight, epsilon.
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normed,
        x,
        residual,
        new_residual,
        weight,
        variance_epsilon,
    )
    return normed, new_residual


def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``rocm_aiter_rms_norm``.

    Registered as the op's fake impl; it only describes the output, an
    uninitialized tensor shaped and typed like ``x``.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (metadata-only) implementation of the fused add + RMSNorm op.

    Returns uninitialized tensors mirroring ``x`` and ``residual``, matching
    the ``(output, residual_out)`` pair produced by the real implementation.
    """
    normed_like = torch.empty_like(x)
    residual_like = torch.empty_like(residual)
    return normed_like, residual_like


if current_platform.is_rocm():
    # Expose the AITER kernels as torch.ops.vllm.* custom ops (with fake
    # impls for tracing). This only happens on ROCm, so those op names are
    # absent on other platforms.
    for _op_name, _op_func, _fake_impl in (
        ("rocm_aiter_rms_norm", rocm_aiter_rms_norm_impl,
         rocm_aiter_rms_norm_fake),
        ("rocm_aiter_rmsnorm2d_fwd_with_add",
         rocm_aiter_rmsnorm2d_fwd_with_add_impl,
         rocm_aiter_rmsnorm2d_fwd_with_add_fake),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_fake_impl,
        )
    # Keep the module namespace free of loop temporaries.
    del _op_name, _op_func, _fake_impl


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Pick the RMSNorm callable to use on ROCm.

    AITER kernels are selected only when ``is_rocm_aiter_rmsnorm_enabled()``
    is truthy and ``dtype`` is float16 or bfloat16. The returned
    ``torch.ops.vllm.rocm_aiter_*`` ops exist only if they were registered
    (this module registers them on ROCm only). Otherwise this falls back to
    the vLLM ``fused_add_rms_norm`` / ``rms_norm`` wrappers.

    Args:
        with_fused_add: Select the residual-add variant.
        dtype: Dtype of the norm weight.

    Returns:
        A callable taking ``(x, weight, eps)`` or, for the fused variant,
        ``(x, residual, weight, eps)``.
    """
    aiter_supported = is_rocm_aiter_rmsnorm_enabled() and dtype in (
        torch.float16, torch.bfloat16)

    if not aiter_supported:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
    return torch.ops.vllm.rocm_aiter_rms_norm


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Expected size of the last input dimension.
        eps: Epsilon added to the mean square before the reciprocal sqrt.
        var_hidden_size: When given and different from ``hidden_size``, the
            variance is computed over only the first ``var_hidden_size``
            features of the last dimension (the whole last dimension is still
            scaled). Instances configured this way always use
            ``forward_native``.
        has_weight: If True the scale is a learnable ``nn.Parameter``;
            otherwise it is a plain all-ones tensor. ``forward_native`` skips
            the multiply in that case, while the kernel-backed paths still
            pass the tensor to the kernel.
        dtype: Optional dtype for the scale tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "compute the variance over the full hidden size".
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight
        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, based on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype)
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The computation runs in float32. When ``residual`` is given, ``x`` is
        replaced by ``x + residual`` and that sum, cast back to ``x``'s
        original dtype, is returned as the new residual.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when a
            residual was supplied.

        Raises:
            ValueError: If the last dimension of ``x`` is not
                ``hidden_size``, or is smaller than ``var_hidden_size``.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice only the last dimension so inputs of any rank work;
            # x[:, :, :n] would require exactly 3-D input and fail on the
            # common 2-D (num_tokens, hidden_size) layout.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Unlike GemmaRMSNorm, cast back before applying the weight.
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path using vLLM's compiled RMSNorm kernels.

        Falls back to ``forward_native`` when ``var_hidden_size`` is in use,
        since the kernels are only called with the full weight.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """ROCm path using the functions chosen in ``__init__``.

        ``rocm_norm_func``/``rocm_norm_func_with_add`` are only set when
        ``__init__`` ran on ROCm.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
                                                self.variance_epsilon)
        else:
            return self.rocm_norm_func(x, self.weight.data,
                                       self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """XPU path using the IPEX ops.

        With a residual, ``ops.fused_add_rms_norm`` is called on ``x`` and
        ``residual`` and those same tensors are returned (presumably updated
        in place by the op -- TODO confirm).
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Summarize the configuration for ``repr(module)``."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from RMSNorm in this module in two ways:
        1. The scale is applied as ``x * (1 + w)`` rather than ``x * w``;
           the weight starts at zeros.
        2. The scale is applied in float32 before casting back, i.e.
           ``(x * w).to(orig_dtype)`` instead of ``x.to(orig_dtype) * w``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Kept stateless so ``forward_cuda`` can wrap it with torch.compile.
        When ``residual`` is given, ``x + residual`` becomes the input and is
        also returned as the new residual. For float16 ``x`` the residual is
        upcast first, so that sum (and the returned residual) is float32.
        """
        input_dtype = x.dtype
        if residual is not None:
            addend = (residual.float()
                      if input_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        out = hidden.to(input_dtype)
        if residual is None:
            return out
        return out, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Delegates to ``self.forward_static``, which is the compiled version
        once ``forward_cuda`` has replaced it on this instance.
        """
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: run ``forward_static`` under torch.compile.

        Outside of an active compilation, the first call replaces
        ``forward_static`` on this instance with a ``torch.compile``-wrapped
        version (at most once per instance). While already compiling, the
        plain implementation is traced directly.
        """
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            # The instance attribute shadows the staticmethod, so
            # forward_native picks up the compiled callable from now on.
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1

    The three weights start at 1/3 each and the scalar bias at zero.
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Unweighted RMS normalization over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Evaluated in float32 and cast back to the input dtype.
        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        x32 = x.to(torch.float32)
        cubic = self.weight[0] * self._norm(x32**3)
        quadratic = self.weight[1] * self._norm(x32**2)
        linear = self.weight[2] * self._norm(x32)
        combined = cubic + quadratic + linear + self.bias
        return combined.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path using vLLM's compiled PolyNorm kernel."""
        return poly_norm(x=x,
                         weight=self.weight,
                         bias=self.bias,
                         variance_epsilon=self.variance_epsilon)


class LayerNorm(nn.Module):
    """Layer normalization evaluated in float32.

    The input is upcast to float32 and normalized over its last dimension
    (of size ``dim``) with affine parameters created in float32. The result
    is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim, self.eps = dim, eps
        # Registration order (weight, then bias) fixes state_dict order.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim, ),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)