# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Return whether the ROCm AITER RMSNorm kernels should be used.

    Requires both the RMSNorm-specific switch ``VLLM_ROCM_USE_AITER_RMSNORM``
    and the global AITER switch ``VLLM_ROCM_USE_AITER`` to be set.
    """
    return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER


def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMSNorm to ``x`` using vLLM's compiled ``rms_norm`` op.

    Args:
        x: Input activations.
        weight: Scale passed to the kernel together with ``x``.
        variance_epsilon: Epsilon passed to the kernel.

    Returns:
        A newly allocated tensor shaped like ``x``. It is passed as the first
        argument of ``ops.rms_norm``, presumably as the output buffer.
    """
    if vllm_is_batch_invariant():
        # Batch-invariant mode delegates to the dedicated implementation
        # instead of the custom kernel.
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    # Only needed on the kernel path, so import it after the early return.
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    ops.rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual stream to add to ``x``.
        weight: Normalization scale.
        variance_epsilon: Epsilon used by the normalization.

    Returns:
        ``(normalized, new_residual)``. On the custom-kernel path these are
        the very ``x`` and ``residual`` tensors that were passed in. The op's
        return value is ignored, so the kernel presumably updates them in
        place.
    """
    if vllm_is_batch_invariant():
        # Form the residual sum once and reuse it for both outputs instead
        # of evaluating ``x + residual`` twice.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMSNorm via AITER's ``rms_norm``.

    Inputs with at most two dimensions are passed straight through. Inputs
    with more dimensions are flattened to 2-D (all leading dimensions merged)
    for the kernel call, and the result is reshaped back to the input shape.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(original_shape)


def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via AITER's ``rmsnorm2d_fwd_with_add``.

    The normalized output and the updated residual are written by the kernel
    into freshly allocated buffers (``output`` shaped like ``x``, and
    ``residual_out`` shaped like ``residual``). Those buffers are returned as
    ``(output, residual_out)``.
    """
    import aiter as rocm_aiter

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out


def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Shape/dtype-only stand-in for ``rocm_aiter_rms_norm_impl``.

    Used as the op's ``fake_impl`` when registering ``rocm_aiter_rms_norm``.
    It returns an uninitialized tensor with the same shape, dtype and device
    as ``x``.
    """
    return torch.empty_like(x)


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for ``rocm_aiter_rmsnorm2d_fwd_with_add_impl``.

    Returns uninitialized tensors shaped like ``x`` and ``residual``. This
    mirrors the ``(output, residual_out)`` pair produced by the real op.
    """
    return torch.empty_like(x), torch.empty_like(residual)


# On ROCm, expose the AITER kernels as torch custom ops, each paired with its
# shape-only fake implementation. They are looked up later as
# ``torch.ops.vllm.<op_name>`` by ``dispatch_rocm_rmsnorm_func``.
if current_platform.is_rocm():
    for _op_name, _op_func, _fake_impl in (
        (
            "rocm_aiter_rms_norm",
            rocm_aiter_rms_norm_impl,
            rocm_aiter_rms_norm_fake,
        ),
        (
            "rocm_aiter_rmsnorm2d_fwd_with_add",
            rocm_aiter_rmsnorm2d_fwd_with_add_impl,
            rocm_aiter_rmsnorm2d_fwd_with_add_fake,
        ),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_fake_impl,
        )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Select the RMSNorm callable to use on ROCm.

    The AITER custom ops are chosen when ``is_rocm_aiter_rmsnorm_enabled()``
    is truthy and ``dtype`` is float16 or bfloat16. Otherwise the generic
    ``rms_norm`` / ``fused_add_rms_norm`` wrappers are returned.

    Args:
        with_fused_add: If True, return the residual-add variant, called as
            ``(x, residual, weight, eps)``. Otherwise return the plain
            variant, called as ``(x, weight, eps)``.
        dtype: Dtype used to decide AITER eligibility.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
        torch.float16,
        torch.bfloat16,
    ]

    if use_aiter and with_fused_add:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
    if use_aiter:
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    if with_fused_add:
        return fused_add_rms_norm
    return rms_norm


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the last dimension of the input.
            eps: Value added to the mean square before ``rsqrt``.
            var_hidden_size: If set and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                channels of the last dimension. The whole last dimension is
                still normalized.
            has_weight: If False, the all-ones weight is kept as a plain
                tensor (not an ``nn.Parameter``), and the native path skips
                the weight multiply.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # ``None`` means "compute the variance over the full hidden size".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Resolve the AITER-or-fallback callables once, at construction.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Statistics are computed in float32. The normalized activations are
        cast back to ``orig_dtype`` before the optional ``weight`` is applied.

        Returns:
            The normalized tensor when ``residual`` is None. Otherwise
            ``(normalized, x + residual)``, with the updated residual cast to
            ``orig_dtype``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``,
                or if ``variance_size_override`` exceeds ``hidden_size``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last (hidden) dimension so inputs of any rank
            # work, e.g. 2-D (tokens, hidden) as well as 3-D inputs.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run through the ``rms_norm`` / ``fused_add_rms_norm`` wrappers."""
        # The kernel wrappers take no variance-size argument, so a partial
        # variance window must go through the native implementation.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run through the callables chosen by ``dispatch_rocm_rmsnorm_func``."""
        # Same restriction as forward_cuda: no partial-variance support.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run through the IPEX ``rms_norm`` / ``fused_add_rms_norm`` ops."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # The op's return value is ignored and ``x``/``residual`` are
            # returned, so it presumably updates them in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialized: the effective scale is ``1 + weight``, so a zero
        # weight starts out as the identity scale.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor in ``x``'s original dtype. When
        ``residual`` is given, returns ``(normalized, x + residual)`` instead.
        For float16 inputs the sum is formed in float32, and the returned
        residual is that float32 tensor.
        """
        orig_dtype = x.dtype
        if residual is not None:
            x = (
                x.float() + residual.float()
                if orig_dtype == torch.float16
                else x + residual
            )
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Already inside a torch.compile trace: run the plain PyTorch code and
        # let the enclosing compilation handle it.
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        # Eager mode: on first use, replace this instance's forward_static
        # with a torch.compile'd version. The instance attribute shadows the
        # class-level staticmethod, so forward_native picks it up.
        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    Normalizes over the last dimension (of size ``dim``) in float32, using
    float32 affine ``weight``/``bias``. The result is cast back to the input's
    dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32 to match the float32 parameters; ``type_as``
        # restores the caller's dtype.
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)