"vscode:/vscode.git/clone" did not exist on "0c1809c8065f51646c9ca6ce7911831a763a5d18"
layernorm.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
zhuwenwen's avatar
zhuwenwen committed
4
from typing import Optional, Union, Tuple
5

6
7
8
import torch
import torch.nn as nn

zhuwenwen's avatar
zhuwenwen committed
9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
from vllm.platforms import current_platform
13
from vllm.utils import direct_register_custom_op
14
15
16


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the AITER RMSNorm kernels are switched on for ROCm.

    Both ``VLLM_ROCM_USE_AITER_RMSNORM`` and the global ``VLLM_ROCM_USE_AITER``
    flag must be truthy; the second flag is only read when the first is.
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    return (rmsnorm_flag and envs.VLLM_ROCM_USE_AITER)


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm with the compiled vLLM kernel into a new output tensor.

    ``ops.rms_norm_opt`` is used when ``envs.VLLM_USE_OPT_OP`` is set,
    otherwise ``ops.rms_norm``. Both receive the freshly allocated output
    buffer as their first argument, and that buffer is returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    kernel = ops.rms_norm_opt if envs.VLLM_USE_OPT_OP else ops.rms_norm
    kernel(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the fused residual-add + RMSNorm kernel on ``x``/``residual``.

    No output buffer is allocated: the kernel is handed ``x`` and
    ``residual`` directly and those same tensors are returned, so the kernel
    presumably writes its results into them in place (confirm against the
    kernel's definition). ``ops.fused_add_rms_norm_opt`` is used when
    ``envs.VLLM_USE_OPT_OP`` is set, otherwise ``ops.fused_add_rms_norm``.
    """
    from vllm import _custom_ops as ops

    if envs.VLLM_USE_OPT_OP:
        kernel = ops.fused_add_rms_norm_opt
    else:
        kernel = ops.fused_add_rms_norm
    kernel(x, residual, weight, variance_epsilon)
    return x, residual


63
64
65
66
67
68
69
70
71
72
73
74
75
76
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
    """Apply PolyNorm through the compiled ``ops.poly_norm`` kernel.

    The kernel writes into a newly allocated tensor shaped like ``x``,
    which is returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


77
78
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """RMSNorm via AITER; inputs of rank > 2 are flattened to 2-D first."""
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # Collapse all leading dims so the AITER call always sees a 2-D tensor,
    # then restore the caller's shape on the result.
    original_shape = x.shape
    flat_input = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat_input, weight, variance_epsilon)
    return normed.reshape(original_shape)


89
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through AITER, writing to new buffers.

    Returns ``(normalized_output, residual_output)``; both are freshly
    allocated with the shapes/dtypes of ``x`` and ``residual`` respectively.
    """
    import aiter as rocm_aiter

    new_residual = torch.empty_like(residual)
    normed = torch.empty_like(x)
    # Argument order: output, input, residual in, residual out, weight, eps.
    rocm_aiter.rmsnorm2d_fwd_with_add(normed, x, residual, new_residual,
                                      weight, variance_epsilon)
    return normed, new_residual
106
107


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Fake implementation: an uninitialized tensor with ``x``'s metadata."""
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation: uninitialized tensors shaped like x and residual."""
    fake_output = torch.empty_like(x)
    fake_residual = torch.empty_like(residual)
    return fake_output, fake_residual


if current_platform.is_rocm():
    # Register each AITER kernel as a torch custom op (reachable as
    # torch.ops.vllm.<op_name>), paired with its *_fake counterpart.
    for _op_name, _op_impl, _op_fake in (
        ("rocm_aiter_rms_norm", rocm_aiter_rms_norm_impl,
         rocm_aiter_rms_norm_fake),
        ("rocm_aiter_rmsnorm2d_fwd_with_add",
         rocm_aiter_rmsnorm2d_fwd_with_add_impl,
         rocm_aiter_rmsnorm2d_fwd_with_add_fake),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_impl,
            mutates_args=[],
            fake_impl=_op_fake,
            dispatch_key=current_platform.dispatch_key,
        )
    # Keep the module namespace identical to the unrolled form.
    del _op_name, _op_impl, _op_fake

136

137
138
139
140
141
142
143
144
145
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Select the RMSNorm callable to use on ROCm.

    The AITER custom ops are returned when AITER RMSNorm is enabled and
    ``dtype`` is float16 or bfloat16. Otherwise the generic ``rms_norm`` /
    ``fused_add_rms_norm`` wrappers are returned. ``with_fused_add`` picks
    the residual-add variant in either case.
    """
    aiter_ok = is_rocm_aiter_rmsnorm_enabled() and dtype in [
        torch.float16, torch.bfloat16
    ]

    if not aiter_ok:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
    return torch.ops.vllm.rocm_aiter_rms_norm
151
152


153
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the reciprocal sqrt.
            var_hidden_size: If set and different from ``hidden_size``, only
                the first ``var_hidden_size`` features of the last dimension
                contribute to the variance. Every forward path then falls
                back to ``forward_native``.
            has_weight: If True the scale is a learnable ``nn.Parameter``.
                Otherwise it is a plain all-ones tensor, and
                ``forward_native`` skips the multiply.
            dtype: Optional dtype for the weight tensor.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Only keep an override when it actually differs from hidden_size,
        # so the fused kernels remain usable in the common case.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight
        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, based on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype)
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        When ``residual`` is given, it is added to ``x`` in float32. The sum,
        cast back to the input dtype, is returned as the new residual
        alongside the normalized output.

        Raises:
            ValueError: if ``x.shape[-1]`` differs from ``hidden_size``, or is
                smaller than the variance-size override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the last (feature) dimension for any input rank; the
            # variance below is also reduced over dim=-1.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path using the compiled vLLM kernels."""
        # The fused kernels cannot compute a partial-width variance.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """ROCm path using the kernels chosen in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
                                                self.variance_epsilon)
        # Previously referenced an undefined `norm_func` (NameError).
        return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Path using NVIDIA Apex's fused RMSNorm for the no-residual case.

        The residual case uses vLLM's ``fused_add_rms_norm`` kernel.
        """
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        # Apex normalizes over the full last dimension, so a partial-width
        # variance must go through the native path (as on other backends).
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            # Previously called an undefined `dispatch_cuda_rmsnorm_func`.
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        return fused_rms_norm_affine(x, self.weight.data,
                                     torch.Size((x.shape[-1], )),
                                     self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """XPU path using the IPEX kernels."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # Outputs are read back from x and residual after the call.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
301
302


303
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the standard RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialized: the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Stateless Gemma RMSNorm shared by the native and CUDA paths.

        ``forward_cuda`` wraps this function in ``torch.compile``. When
        ``residual`` is given, ``x + residual`` becomes the returned residual.
        For float16 inputs the residual is upcast before the add, so that
        sum (and the returned residual) is float32.
        """
        input_dtype = x.dtype
        if residual is not None:
            addend = (residual.float()
                      if input_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = (hidden * (1.0 + weight.float())).to(input_dtype)

        if residual is None:
            return hidden
        return hidden, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Eager path: run ``forward_static`` with this module's parameters."""
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: lazily ``torch.compile`` ``forward_static`` on first use.

        Inside an active ``torch.compile`` trace no compile is set up; the
        native path runs directly. The compiled function is stored on the
        instance, so ``forward_native`` picks it up on later calls.
        """
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Three mixing weights (for x^3, x^2, x), each initialized to 1/3.
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Unweighted RMS normalization over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        xf = x.to(torch.float32)
        cubic_term = self.weight[0] * self._norm(xf**3)
        quadratic_term = self.weight[1] * self._norm(xf**2)
        linear_term = self.weight[2] * self._norm(xf)
        combined = cubic_term + quadratic_term + linear_term + self.bias
        return combined.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path via the compiled ``poly_norm`` kernel wrapper."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)