"vllm/vscode:/vscode.git/clone" did not exist on "4934d492744d14104353b8236ef8a0405edf1622"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
4
from typing import Optional, Union
5

6
7
8
import torch
import torch.nn as nn

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
from vllm.platforms import current_platform
12
from vllm.utils import direct_register_custom_op
13
14
15


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Tell whether the AITER RMSNorm kernels have been switched on.

    Both ``envs.VLLM_ROCM_USE_AITER_RMSNORM`` and ``envs.VLLM_ROCM_USE_AITER``
    must be truthy.

    Returns:
        The value of ``VLLM_ROCM_USE_AITER_RMSNORM and VLLM_ROCM_USE_AITER``.
    """
    # Parenthesised instead of a backslash continuation so that stray text
    # between the two operands can no longer split the expression.
    return (envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER)


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Run vLLM's ``rms_norm`` custom op and return a new tensor.

    Args:
        x: Input tensor.
        weight: Scale tensor forwarded to the op.
        variance_epsilon: Epsilon forwarded to the op.

    Returns:
        A tensor allocated with ``torch.empty_like(x)`` that is handed to the
        op as its first (output) argument.
    """
    # Imported at call time rather than at module import.
    from vllm import _custom_ops as kernels

    normalized = torch.empty_like(x)
    kernels.rms_norm(normalized, x, weight, variance_epsilon)
    return normalized


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Run vLLM's ``fused_add_rms_norm`` custom op on ``x`` and ``residual``.

    Args:
        x: Input tensor; the same object is returned as the first result.
        residual: Residual tensor; the same object is returned as the second
            result.
        weight: Scale tensor forwarded to the op.
        variance_epsilon: Epsilon forwarded to the op.

    Returns:
        ``(x, residual)`` -- the very tensors that were passed in. No new
        tensors are allocated here, so any result is presumably written into
        them by the op (NOTE(review): confirm against the kernel).
    """
    # Imported at call time rather than at module import.
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


46
47
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """RMSNorm through ``aiter.rms_norm``, flattening inputs above rank 2.

    Args:
        x: Input tensor of any rank.
        weight: Scale tensor forwarded to ``aiter.rms_norm``.
        variance_epsilon: Epsilon forwarded to ``aiter.rms_norm``.

    Returns:
        The result of ``aiter.rms_norm``. For inputs with more than two
        dimensions the result is reshaped back to the shape of ``x``.
    """
    # Imported at call time so the module loads without ``aiter`` installed.
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    # Collapse all leading dimensions into one, normalize, then restore the
    # caller's shape.
    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normalized = rocm_aiter.rms_norm(flattened, weight, variance_epsilon)
    return normalized.reshape(original_shape)


58
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual-add RMSNorm through ``aiter.rmsnorm2d_fwd_with_add``.

    Args:
        x: Input tensor.
        residual: Residual input tensor.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        ``(output, residual_out)``: two tensors allocated here with
        ``torch.empty_like`` (shaped like ``x`` and ``residual``) and handed
        to the kernel as its output arguments.
    """
    # Imported at call time so the module loads without ``aiter`` installed.
    import aiter as rocm_aiter

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out
75
76


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Shape-only stand-in for ``rocm_aiter_rms_norm_impl``.

    ``weight`` and ``variance_epsilon`` are accepted but unused.

    Returns:
        An uninitialized tensor with the shape and dtype of ``x``.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for ``rocm_aiter_rmsnorm2d_fwd_with_add_impl``.

    ``weight`` and ``variance_epsilon`` are accepted but unused.

    Returns:
        Two uninitialized tensors: one like ``x`` and one like ``residual``.
    """
    fake_output = torch.empty_like(x)
    fake_residual = torch.empty_like(residual)
    return fake_output, fake_residual


# Register the AITER-backed implementations as custom ops, but only when the
# current platform reports ROCm. dispatch_rocm_rmsnorm_func() resolves them
# later as ``torch.ops.vllm.<op_name>``.
if current_platform.is_rocm():
    # Plain RMSNorm. Declared with no mutated arguments; the fake impl only
    # allocates a tensor like ``x``.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rms_norm_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # RMSNorm with residual add. The real impl allocates both of its outputs
    # itself, so it is likewise declared with no mutated arguments.
    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
        dispatch_key=current_platform.dispatch_key,
    )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Choose the RMSNorm callable to use on ROCm.

    The AITER ops are selected only when ``is_rocm_aiter_rmsnorm_enabled()``
    is truthy and ``dtype`` is ``torch.float16`` or ``torch.bfloat16``.
    Otherwise the module-level ``rms_norm`` / ``fused_add_rms_norm``
    wrappers are returned.

    Args:
        with_fused_add: Whether the caller needs the residual-add variant.
        dtype: Dtype used to decide AITER eligibility.

    Returns:
        A callable taking ``(x, weight, variance_epsilon)``, or
        ``(x, residual, weight, variance_epsilon)`` when ``with_fused_add``
        is true.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in (
        torch.float16, torch.bfloat16)

    if use_aiter:
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    if with_fused_add:
        return fused_add_rms_norm
    return rms_norm
120
121


122
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Expected size of the last dimension of the input.
        eps: Value added to the variance before the reciprocal square root.
        var_hidden_size: If given and different from ``hidden_size``, the
            variance is computed over only the first ``var_hidden_size``
            elements of the last dimension, and every accelerated forward
            path defers to ``forward_native``.
        has_weight: If true, ``weight`` is wrapped in an ``nn.Parameter``;
            otherwise it stays a plain tensor of ones.
        dtype: Optional dtype for ``weight``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # An override equal to hidden_size is a no-op, so it is stored as
        # None to keep the fast paths available.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        # dtype=None makes torch.ones fall back to the default dtype.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        weight_dtype = self.weight.data.dtype

        # The ROCm kernels are picked once, at construction time, from the
        # weight dtype. These attributes exist only on ROCm.
        if current_platform.is_rocm():
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype)
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``self.hidden_size``.
            residual: Optional tensor added to ``x`` (in float32) before
                normalization.

        Returns:
            The normalized tensor in the dtype of ``x``. When ``residual``
            is given, a tuple ``(normalized, x + residual)`` with the sum
            cast back to the dtype of ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` differs from
                ``self.hidden_size`` or is smaller than the variance size
                override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the last dimension with an ellipsis so inputs of any
            # rank work, not just rank-3 tensors.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before applying the weight (x.to(orig_dtype) * w).
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass through the module-level custom-op wrappers.

        Defers to ``forward_native`` when a variance size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(x, residual, self.weight.data,
                                      self.variance_epsilon)
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass using the callables chosen in ``__init__`` for ROCm.

        Defers to ``forward_native`` when a variance size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
                                                self.variance_epsilon)
        else:
            return self.rocm_norm_func(x, self.weight.data,
                                       self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass through ``vllm._ipex_ops``.

        Defers to ``forward_native`` when a variance size override is set.
        With a residual, the same ``x`` and ``residual`` objects that were
        passed in are returned after the fused op has been called on them.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Imported at call time rather than at module import.
        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size=..., eps=...`` summary of this layer."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


258
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.

    Args:
        hidden_size: Length of the learned weight, which is initialized to
            zeros.
        eps: Value added to the variance before the reciprocal square root.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zeros, because the forward pass scales by (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            weight: Learned weight; the applied scale is ``1 + weight``.
            variance_epsilon: Value added to the variance.
            x: Input tensor, normalized over its last dimension.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor cast back to the dtype of ``x``. When
            ``residual`` is given, a tuple ``(normalized, x + residual)``
            where the sum is returned exactly as computed (no cast back).
        """
        orig_dtype = x.dtype
        if residual is not None:
            # For float16 inputs the residual is upcast before the add;
            # other dtypes are added as-is.
            if orig_dtype == torch.float16:
                x = x + residual.float()
            else:
                x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass that compiles ``forward_static`` on first eager use.

        While ``torch.compiler.is_compiling()`` is true, the native path is
        used directly. Otherwise ``forward_static`` is replaced on this
        instance by its ``torch.compile`` wrapper once, guarded by the
        ``_is_compiled`` flag, and then the native path is called.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)