# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
from typing import Optional, Union

import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the AITER RMSNorm kernels should be dispatched to.

    Three conditions must all be truthy, checked in this order with
    short-circuiting:
      1. the current platform is ROCm,
      2. ``envs.VLLM_ROCM_USE_AITER_RMSNORM``,
      3. ``envs.VLLM_ROCM_USE_AITER``.
    The environment flags are therefore only read when running on ROCm.
    """
    on_rocm = current_platform.is_rocm()
    if not on_rocm:
        return on_rocm

    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not rmsnorm_flag:
        return rmsnorm_flag

    return envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm through the compiled ``_custom_ops.rms_norm`` kernel.

    Args:
        x: Input activations.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A newly allocated tensor (``torch.empty_like(x)``) that the kernel
        fills as its output argument.
    """
    # Imported inside the function so the compiled extension is only
    # pulled in when this wrapper is actually called.
    from vllm import _custom_ops as ops

    normalized = torch.empty_like(x)
    # Argument order is (output, input, weight, epsilon).
    ops.rms_norm(normalized, x, weight, variance_epsilon)
    return normalized


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through ``_custom_ops``.

    No output buffers are allocated or passed, so the kernel presumably
    updates ``x`` and ``residual`` in place. This wrapper hands back those
    same two tensor objects.

    Args:
        x: Input activations.
        residual: Residual tensor combined with ``x`` by the kernel.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        ``(x, residual)`` — the very tensors that were passed in.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm through AITER's ``rms_norm`` (ROCm).

    Inputs of rank <= 2 are passed straight through. Higher-rank inputs are
    reshaped to 2-D before the call and reshaped back afterwards: all
    leading dims are merged and the last dim is kept.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # NOTE(review): flattening suggests the AITER kernel expects 2-D input
    # — confirm against the aiter API.
    full_shape = x.shape
    flat = x.reshape(-1, full_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(full_shape)


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through AITER (ROCm).

    Fresh output buffers (``torch.empty_like``) are allocated for both the
    normalized result and the updated residual. These are passed to
    ``rmsnorm2d_fwd_with_add`` alongside the inputs.

    Returns:
        ``(normalized, updated_residual)`` — the newly allocated tensors,
        shaped like ``x`` and ``residual`` respectively.
    """
    import aiter as rocm_aiter

    normalized = torch.empty_like(x)
    updated_residual = torch.empty_like(residual)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normalized,  # output
        x,  # input
        residual,  # residual input
        updated_residual,  # residual output
        weight,
        variance_epsilon,
    )
    return normalized, updated_residual


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Select the RMSNorm wrapper used by ``RMSNorm.forward_cuda``.

    Args:
        add_residual: When true, return a fused add+norm wrapper called as
            ``(x, residual, weight, eps)``. Otherwise return a plain norm
            wrapper called as ``(x, weight, eps)``.

    Returns:
        The AITER variant when ``is_rocm_aiter_rmsnorm_enabled()`` is truthy,
        otherwise the ``_custom_ops`` variant.
    """
    if is_rocm_aiter_rmsnorm_enabled():
        if add_residual:
            return rocm_aiter_fused_add_rms_norm
        return rocm_aiter_rms_norm

    if add_residual:
        return fused_add_rms_norm
    return rms_norm


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    ``var_hidden_size`` may be set to a value different from
    ``hidden_size``. In that case the mean of squares uses only the first
    ``var_hidden_size`` channels of the last dimension. Every backend then
    falls back to ``forward_native``, because the fused kernels are not
    passed that override.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Args:
            hidden_size: Required size of the input's last dimension.
            eps: Added to the variance before the reciprocal square root.
            var_hidden_size: Number of leading channels used to compute the
                variance. ``None`` (or a value equal to ``hidden_size``)
                means all channels.
            has_weight: If False, no learnable scale is applied. A tensor
                of ones is still kept in ``self.weight`` as a plain tensor
                (not an ``nn.Parameter``), because ``forward_cuda`` and
                ``forward_xpu`` pass ``self.weight.data`` unconditionally.
            dtype: dtype of the weight. ``None`` uses torch's default dtype.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        # dtype=None makes torch.ones use the default dtype.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input of any rank whose last dimension equals
                ``hidden_size``.
            residual: Optional tensor added to ``x`` in float32 before
                normalization.

        Returns:
            The normalized tensor in ``x``'s original dtype. When
            ``residual`` is given, returns ``(normalized, x + residual)``,
            with the sum cast back to ``x``'s original dtype.

        Raises:
            ValueError: If the last dimension differs from ``hidden_size``,
                or is smaller than the variance-size override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the LAST dimension, which is the one validated above.
            # This works for inputs of any rank, including 2-D.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before applying the weight (Gemma does the opposite).
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA/ROCm path using the wrapper chosen by
        ``dispatch_cuda_rmsnorm_func``.

        Falls back to ``forward_native`` when a variance-size override is
        set, because the kernels are not passed that override.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        else:
            return norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """HPU path using vllm_hpu_extension's fused RMSNorm when available.

        Falls back to ``forward_native`` in two cases: a variance-size
        override is set (the fused kernel is not passed it, matching
        ``forward_cuda``/``forward_xpu``), or the extension returns no kernel.
        With a residual, ``residual`` is updated in place (``+= x``) and
        returned alongside the normalized output.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm_hpu_extension.kernels import rms_norm
        HPUFusedRMSNorm = rms_norm()
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """XPU path through ``ipex_ops``.

        Falls back to ``forward_native`` when a variance-size override is
        set. With a residual, ``ipex_ops.fused_add_rms_norm`` is called and
        ``(x, residual)`` are returned as-is. That call presumably updates
        them in place — confirm against ipex_ops.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization as used by Gemma.

    Differs from ``RMSNorm`` in two ways:
        1. The scale is ``(1 + w)`` rather than ``w``. The weight is
           zero-initialised, so the effective scale starts at 1.
        2. The scale is applied in float32 and only then cast back, i.e.
           ``(x * w).to(orig_dtype)`` rather than ``x.to(orig_dtype) * w``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor in ``x``'s dtype. When a residual is
        given, returns ``(normalized, x + residual)``. For float16 inputs
        that sum is float32, not float16.
        """
        out_dtype = x.dtype
        has_residual = residual is not None
        if has_residual:
            # float16: promote the residual so the sum (and hence the
            # returned residual) is float32.
            # NOTE(review): presumably deliberate, to keep the residual
            # stream out of fp16 range limits — confirm.
            addend = (residual.float()
                      if out_dtype == torch.float16 else residual)
            x = x + addend
            residual = x

        hidden = x.float()
        variance = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        hidden = hidden.to(out_dtype)
        if has_residual:
            return hidden, residual
        return hidden

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Dispatches through ``self.forward_static``. After ``forward_cuda``
        has run, that name resolves to the torch.compile'd wrapper installed
        on the instance.
        """
        weight = self.weight.data
        eps = self.variance_epsilon
        return self.forward_static(weight, eps, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: lazily wrap ``forward_static`` in torch.compile.

        The compiled wrapper is stored on the instance the first time this
        runs outside an active torch.compile trace. Inside a trace, the
        plain native path is used and nothing is compiled.
        """
        tracing = torch.compiler.is_compiling()
        if not tracing and not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)