layernorm.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
zhuwenwen's avatar
zhuwenwen committed
4
from typing import Optional, Union, Tuple
5

6
7
8
import torch
import torch.nn as nn

zhuwenwen's avatar
zhuwenwen committed
9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
from vllm.platforms import current_platform
13
from vllm.utils import direct_register_custom_op
14
15
16
17
18
19
20
21
22
23
24
25


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm kernels should be used.

    All three conditions must hold: the current platform is ROCm, the
    ``VLLM_ROCM_USE_AITER_RMSNORM`` switch is set, and the global
    ``VLLM_ROCM_USE_AITER`` switch is set.
    """
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
            and envs.VLLM_ROCM_USE_AITER)


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm to ``x`` using the vLLM custom-op kernels.

    A fresh output tensor shaped like ``x`` is allocated and handed to the
    kernel, then returned. ``envs.VLLM_USE_OPT_OP`` selects the
    ``rms_norm_opt`` kernel; otherwise the default ``rms_norm`` is used.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    kernel = ops.rms_norm_opt if envs.VLLM_USE_OPT_OP else ops.rms_norm
    kernel(result, x, weight, variance_epsilon)
    return result


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def rms_norm_opt(x: torch.Tensor, weight: torch.Tensor,
                 variance_epsilon: float) -> torch.Tensor:
    """RMSNorm backed by ``lightop.fused_rms_norm_contiguous``.

    Allocates an output tensor shaped like ``x``, passes it to the lightop
    kernel together with ``x``, ``weight`` and ``variance_epsilon``, and
    returns it. This function is the implementation registered as the
    ``rms_norm_opt`` custom op (invoked as ``torch.ops.vllm.rms_norm_opt``).

    NOTE(review): the kernel name suggests it expects contiguous inputs —
    confirm callers guarantee this.
    """
    from lightop import fused_rms_norm_contiguous

    out = torch.empty_like(x)
    fused_rms_norm_contiguous(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


def rms_norm_opt_fake(x: torch.Tensor, weight: torch.Tensor,
                      variance_epsilon: float) -> torch.Tensor:
    """Fake implementation of ``rms_norm_opt`` (its ``fake_impl``).

    Only output metadata is produced: an uninitialized tensor with the
    same shape, dtype and device as ``x``. No normalization is computed.
    """
    del weight, variance_epsilon  # output metadata depends on x alone
    return torch.empty_like(x)


# Register rms_norm_opt as the "rms_norm_opt" custom op; RMSNorm's
# forward_cuda_opt calls it through torch.ops.vllm.rms_norm_opt. The op
# returns a new tensor and declares no mutated inputs, and
# rms_norm_opt_fake supplies its output metadata.
direct_register_custom_op(
    op_name="rms_norm_opt",
    op_func=rms_norm_opt,
    fake_impl=rms_norm_opt_fake,
    mutates_args=[],
)


70
71
def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm using the vLLM custom-op kernels.

    No output tensors are allocated: the kernel is handed ``x`` and
    ``residual`` and those same tensor objects are returned, so the results
    are presumably written into them in place. ``envs.VLLM_USE_OPT_OP``
    selects ``fused_add_rms_norm_opt`` over ``fused_add_rms_norm``.
    """
    from vllm import _custom_ops as ops

    if envs.VLLM_USE_OPT_OP:
        kernel = ops.fused_add_rms_norm_opt
    else:
        kernel = ops.fused_add_rms_norm
    kernel(x, residual, weight, variance_epsilon)
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """RMSNorm via the ROCm AITER ``rms_norm`` kernel.

    Inputs with more than two dimensions are flattened to 2-D (all leading
    dimensions merged) before the kernel call, and the result is reshaped
    back to the original shape.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(original_shape)


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via AITER ``rmsnorm2d_fwd_with_add``.

    Fresh buffers are allocated for both results, and the pair
    ``(normalized_output, residual_output)`` is returned.
    """
    import aiter as rocm_aiter

    output = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,        # normalized output
        x,             # input
        residual,      # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out
120
121
122
123
124
125
126
127
128
129
130


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Choose the RMSNorm implementation for the kernel-backed paths.

    Args:
        add_residual: whether the fused residual-add variant is needed.

    Returns:
        The ROCm AITER function when ``is_rocm_aiter_rmsnorm_enabled()`` is
        true, otherwise the vLLM custom-op wrapper. The fused variants take
        ``(x, residual, weight, eps)``; the plain ones ``(x, weight, eps)``.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled()
    if add_residual:
        return (rocm_aiter_fused_add_rms_norm
                if use_aiter else fused_add_rms_norm)
    return rocm_aiter_rms_norm if use_aiter else rms_norm
131
132


133
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: expected size of the input's last dimension.
        eps: value added to the mean square before the reciprocal sqrt.
        var_hidden_size: if set and different from ``hidden_size``, the
            mean square is computed over only the first ``var_hidden_size``
            features of the last dimension, while all features are scaled.
            Only ``forward_native`` implements this; the kernel-backed paths
            fall back to it.
        has_weight: when False the weight is not an ``nn.Parameter`` and
            ``forward_native`` skips the multiplication by it.
        dtype: dtype of the weight; torch's default dtype when None.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means the variance is taken over the full hidden size.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight
        if dtype is not None:
            self.weight = torch.ones(hidden_size, dtype=dtype)
        else:
            self.weight = torch.ones(hidden_size)
        # NOTE(review): with has_weight=False the all-ones weight is a plain
        # attribute (neither parameter nor buffer), so it is not moved by
        # module.to(device); the kernel paths still pass it — confirm.
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: input whose last dimension must equal ``hidden_size``.
            residual: optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor, or ``(normalized, x + residual)`` when
            ``residual`` is given; both are cast back to ``x``'s dtype.

        Raises:
            ValueError: if the last dimension of ``x`` differs from
                ``hidden_size`` or is smaller than ``var_hidden_size``.
        """
        orig_dtype = x.dtype
        # Compute in float32; cast back before applying the weight.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the LAST dimension regardless of rank: the previous
            # x[:, :, :k] failed on 2-D (tokens, hidden) inputs and sliced
            # the wrong axis for 4-D inputs.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Kernel-backed forward chosen by ``dispatch_cuda_rmsnorm_func``."""
        # Partial-variance normalization exists only in forward_native.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        return norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_cuda_opt(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Like ``forward_cuda`` but the plain path uses the registered
        ``torch.ops.vllm.rms_norm_opt`` custom op."""
        # Partial-variance normalization exists only in forward_native.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        return torch.ops.vllm.rms_norm_opt(x, self.weight.data,
                                           self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward using apex ``fused_rms_norm_affine`` for the plain path.

        The residual path still goes through ``dispatch_cuda_rmsnorm_func``.
        """
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        # Same fallback as forward_cuda: the fused kernels normalize over
        # the whole last dimension and cannot honor var_hidden_size.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        return fused_rms_norm_affine(x, self.weight.data,
                                     torch.Size((x.shape[-1], )),
                                     self.variance_epsilon)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward using the vllm_hpu_extension fused RMSNorm kernel.

        Falls back to ``forward_native`` when the extension returns no
        kernel. With a residual, ``residual`` is updated in place
        (``residual += x``) and the sum is what gets normalized.
        """
        # NOTE(review): unlike forward_cuda/forward_xpu, this path does not
        # check variance_size_override — confirm that is intended.
        from vllm_hpu_extension.kernels import rms_norm
        HPUFusedRMSNorm = rms_norm()
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward using the IPEX kernels from ``vllm._ipex_ops``.

        The residual variant is handed ``x`` and ``residual`` and returns
        those same tensor objects.
        """
        # Partial-variance normalization exists only in forward_native.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
298
299


300
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization as used by Gemma.

    Differs from ``RMSNorm`` in two ways:
        1. The scale is ``(1 + w)`` rather than ``w``; ``weight`` starts
           at zeros.
        2. The weight is applied in float32 and the product is cast to the
           input dtype, i.e. ``(x * w).to(orig_dtype)`` instead of
           ``x.to(orig_dtype) * w``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Effective per-feature scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        With a residual, returns ``(normalized, x + residual)``. For float16
        inputs the residual is upcast before the add, so the returned
        residual is float32; the normalized output always has ``x``'s
        original dtype.
        """
        orig_dtype = x.dtype
        if residual is not None:
            addend = (residual.float()
                      if orig_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        x = x.float()
        inv_rms = torch.rsqrt(
            x.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = (x * inv_rms * (1.0 + weight.float())).to(orig_dtype)

        if residual is None:
            return x
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the native path, wrapping ``forward_static`` in torch.compile.

        On the first call made outside an active compilation, this instance's
        ``forward_static`` is replaced by a ``torch.compile``-wrapped version.
        While already compiling, the native path is traced as-is.
        """
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)