layernorm.py 13.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
zhuwenwen's avatar
zhuwenwen committed
4
from typing import Optional, Union, Tuple
5

6
7
import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

zhuwenwen's avatar
zhuwenwen committed
10
import vllm.envs as envs
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
from vllm.platforms import current_platform
14
from vllm.utils import direct_register_custom_op
15
16
17


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether both AITER switches for RMSNorm are turned on.

    Combines ``envs.VLLM_ROCM_USE_AITER_RMSNORM`` with the global
    ``envs.VLLM_ROCM_USE_AITER`` flag; the second flag is only read when
    the first one is truthy.
    """
    return (envs.VLLM_ROCM_USE_AITER_RMSNORM
            and envs.VLLM_ROCM_USE_AITER)


def rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> torch.Tensor:
    """Run the RMSNorm custom op and return a freshly allocated result.

    Args:
        x: Input tensor; the output buffer is allocated with
            ``torch.empty_like(x)``.
        weight: Weight tensor handed to the kernel.
        variance_epsilon: Epsilon handed to the kernel.

    Returns:
        The output buffer that was passed to the kernel as its first
        argument.
    """
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    # ``VLLM_USE_OPT_OP`` selects the ``*_opt`` variant of the kernel; both
    # variants are called with the same argument list.
    kernel = ops.rms_norm_opt if envs.VLLM_USE_OPT_OP else ops.rms_norm
    kernel(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the fused residual-add + RMSNorm custom op.

    No output buffers are allocated here: the same ``x`` and ``residual``
    objects that were passed in are returned after the kernel call, so the
    kernel is expected to update them in place.

    Args:
        x: Input tensor handed to the kernel.
        residual: Residual tensor handed to the kernel.
        weight: Weight tensor handed to the kernel.
        variance_epsilon: Epsilon handed to the kernel.

    Returns:
        The tuple ``(x, residual)``.
    """
    from vllm import _custom_ops as ops

    # ``VLLM_USE_OPT_OP`` selects the ``*_opt`` variant of the kernel; both
    # variants are called with the same argument list.
    kernel = (ops.fused_add_rms_norm_opt
              if envs.VLLM_USE_OPT_OP else ops.fused_add_rms_norm)
    kernel(x, residual, weight, variance_epsilon)
    return x, residual


64
65
66
67
68
69
70
71
72
73
74
75
76
77
def poly_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    variance_epsilon: float,
) -> torch.Tensor:
    """Run the PolyNorm custom op and return a freshly allocated result.

    Args:
        x: Input tensor; the output buffer is allocated with
            ``torch.empty_like(x)``.
        weight: Weight tensor handed to the kernel.
        bias: Bias tensor handed to the kernel.
        variance_epsilon: Epsilon handed to the kernel.

    Returns:
        The output buffer that was passed to the kernel as its first
        argument.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


78
79
def rocm_aiter_rms_norm_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> torch.Tensor:
    """RMSNorm through ``aiter.rms_norm``.

    Inputs with more than two dimensions are flattened to 2-D (all leading
    dimensions merged, last dimension kept) before the call and the result
    is reshaped back to the input's shape. Inputs with at most two
    dimensions are passed through unchanged.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normalized = rocm_aiter.rms_norm(flattened, weight, variance_epsilon)
    return normalized.reshape(original_shape)


90
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through ``aiter.rmsnorm2d_fwd_with_add``.

    Two new buffers are allocated (shaped like ``residual`` and like ``x``)
    and handed to the AITER call as its output arguments.

    Returns:
        The tuple ``(output, residual_out)`` of those two buffers.
    """
    import aiter as rocm_aiter

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    # Positional order expected by the call:
    # output, input, residual input, residual output, weight, epsilon.
    rocm_aiter.rmsnorm2d_fwd_with_add(output, x, residual, residual_out,
                                      weight, variance_epsilon)
    return output, residual_out
107
108


109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def rocm_aiter_rms_norm_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> torch.Tensor:
    """Fake counterpart of ``rocm_aiter_rms_norm_impl``.

    Performs no computation: it only allocates an uninitialized tensor
    shaped like ``x``. ``weight`` and ``variance_epsilon`` are ignored.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake counterpart of ``rocm_aiter_rmsnorm2d_fwd_with_add_impl``.

    Performs no computation: it only allocates two uninitialized tensors,
    one shaped like ``x`` and one shaped like ``residual``. ``weight`` and
    ``variance_epsilon`` are ignored.
    """
    output = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    return output, residual_out


# Register the AITER-backed RMSNorm functions as custom ops, on ROCm only.
# dispatch_rocm_rmsnorm_func refers to them as torch.ops.vllm.<op_name>.
# NOTE(review): each op is paired with a "fake" function that only allocates
# empty outputs of matching shape; presumably that is what tracing/compilation
# uses in place of the real kernel -- confirm against
# direct_register_custom_op.
if current_platform.is_rocm():
    # Plain RMSNorm.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        fake_impl=rocm_aiter_rms_norm_fake,
    )

    # Fused residual-add + RMSNorm.
    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
    )

133

134
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Pick the RMSNorm callable used by the ROCm forward path.

    Args:
        with_fused_add: Select the fused residual-add variant when true.
        dtype: Weight dtype. Currently unused because the AITER path is
            switched off below.

    Returns:
        ``fused_add_rms_norm`` or ``rms_norm`` while AITER is disabled;
        the matching ``torch.ops.vllm.rocm_aiter_*`` op otherwise.
    """
    # AITER dispatch is hard-disabled here. The condition it replaced was:
    #   is_rocm_aiter_rmsnorm_enabled() and dtype in [
    #       torch.float16, torch.bfloat16
    #   ]
    use_aiter = False

    if use_aiter:
        # Looked up lazily so the ops are only resolved when actually used.
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # Fall back to the custom-op implementations defined in this module.
    return fused_add_rms_norm if with_fused_add else rms_norm
149
150


151
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Value added to the variance before the reciprocal sqrt.
            var_hidden_size: When given and different from ``hidden_size``,
                the variance is computed over only the first
                ``var_hidden_size`` elements of the last dimension.
            has_weight: When true the weight is wrapped in
                ``nn.Parameter``; otherwise it stays a plain tensor of ones
                and ``forward_native`` skips the multiplication.
            dtype: dtype of the weight tensor; ``None`` uses torch's
                default dtype.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # An override equal to hidden_size is the same as no override.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        # dtype=None makes torch.ones fall back to the default dtype.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        # The ROCm callables are resolved once at construction time and are
        # only present as attributes when running on ROCm.
        if current_platform.is_rocm():
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype)
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The computation runs in float32 and the result is cast back to the
        input dtype before the weight is applied.

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when a
            residual was supplied, where ``new_residual`` is
            ``x + residual`` in the input dtype.

        Raises:
            ValueError: If the last dimension does not match
                ``hidden_size`` or is smaller than the variance override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Ellipsis keeps every leading dimension, so inputs of any rank
            # work (a fixed ``[:, :, :n]`` slice only accepts 3-D inputs).
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward through the module-level custom-op wrappers.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward through the callables chosen for ROCm in ``__init__``.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
                                                self.variance_epsilon)
        return self.rocm_norm_func(x, self.weight.data,
                                   self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward using Apex's ``fused_rms_norm_affine`` when possible.

        Without a residual the Apex op is called with the last dimension of
        ``x`` as the normalized shape. With a residual the fused-add
        callable is used instead. Falls back to ``forward_native`` when a
        variance size override is configured, like the other forward paths.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            # ``rocm_norm_func_with_add`` only exists on ROCm; use the
            # generic fused-add wrapper everywhere else.
            fused_add = getattr(self, "rocm_norm_func_with_add",
                                fused_add_rms_norm)
            return fused_add(x, residual, self.weight.data,
                             self.variance_epsilon)

        # Imported lazily so Apex is only required when this path is taken.
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine
        return fused_rms_norm_affine(x, self.weight.data,
                                     torch.Size((x.shape[-1], )),
                                     self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward through the IPEX ops.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            # The same ``x`` and ``residual`` objects are handed back after
            # the call.
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return ``hidden_size=..., eps=...`` for the module repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
300
301


302
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from RMSNorm in two ways:
        1. The scale is (1 + w) rather than w.
        2. The cast back to the input dtype happens after the scale is
           applied: (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create a zero-initialized weight of size ``hidden_size``."""
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor, or ``(normalized, new_residual)``
        when ``residual`` is given, where ``new_residual`` is the sum of
        ``x`` and ``residual``.
        """
        input_dtype = x.dtype
        has_residual = residual is not None
        if has_residual:
            # For float16 inputs the residual is promoted to float32 before
            # the addition; other dtypes are added as they are.
            addend = (residual.float()
                      if input_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        x = x.float()
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = (x * (1.0 + weight.float())).to(input_dtype)

        if has_residual:
            return x, residual
        return x

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(
            self.weight.data,
            self.variance_epsilon,
            x,
            residual,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run ``forward_native``, compiling ``forward_static`` on first use.

        The compile step is skipped while torch is already compiling and
        once the ``_is_compiled`` marker has been set on this instance.
        """
        needs_compile = not (torch.compiler.is_compiling()
                             or getattr(self, "_is_compiled", False))
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(self, eps: float = 1e-6) -> None:
        """Initialize three equal weights of 1/3 and a single zero bias."""
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Divide ``x`` by sqrt(mean(x^2) + eps) over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        The computation runs in float32 and the result is cast back to the
        input dtype.

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        x_float = x.to(torch.float32)

        cubic_term = self.weight[0] * self._norm(x_float**3)
        square_term = self.weight[1] * self._norm(x_float**2)
        linear_term = self.weight[2] * self._norm(x_float)

        # Summed left to right: cubic, square, linear, then bias.
        output = cubic_term + square_term + linear_term + self.bias
        return output.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward through the module-level ``poly_norm`` custom-op wrapper."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, computed in float32.

    The weight (ones) and bias (zeros) are created as float32 parameters.
    The input is cast to float32 for the computation and the result is
    cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.float(),
            (self.dim, ),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)