layernorm.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
12
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
13
    vllm_is_batch_invariant,
14
)
15
from vllm.platforms import current_platform
16
from vllm.utils.torch_utils import direct_register_custom_op
17
18
19


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm kernels are enabled.

    Both the RMSNorm-specific AITER flag and the global AITER flag must be
    truthy. The RMSNorm flag is read first; the global flag is only read
    when the RMSNorm flag is truthy.
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not rmsnorm_flag:
        return rmsnorm_flag
    return envs.VLLM_ROCM_USE_AITER
21
22


23
24
25
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMSNorm to ``x`` through vLLM's custom kernel.

    When batch-invariant mode is active, ``rms_norm_batch_invariant`` is
    used instead of the custom kernel.

    Args:
        x: Input activations.
        weight: Per-feature scale passed to the kernel.
        variance_epsilon: Epsilon passed to the kernel for numerical stability.

    Returns:
        A freshly allocated tensor (``empty_like(x)``) that the kernel fills.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual stream to add to ``x``.
        weight: Per-feature scale.
        variance_epsilon: Epsilon for numerical stability.

    Returns:
        ``(normalized, new_residual)`` where ``new_residual`` is ``x + residual``.
        In batch-invariant mode both are new tensors and the inputs are left
        untouched. Otherwise the custom kernel is called and ``x`` and
        ``residual`` themselves are returned. This presumes the kernel writes
        its results into them in place; verify against the kernel.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once. It is both the normalization input and the new
        # residual, so there is no need to materialize it twice.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


61
62
63
def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run AITER's ``rms_norm`` kernel on ``x``.

    Inputs with at most two dimensions go to the kernel unchanged. Higher-rank
    inputs are first collapsed to 2-D (all leading dimensions merged, last
    dimension kept), and the result is reshaped back to the original shape.
    The presumed reason is that the kernel only accepts 2-D inputs; confirm
    against AITER.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flat, weight, variance_epsilon)
    return normed.reshape(original_shape)


75
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual add + RMSNorm via AITER's ``rmsnorm2d_fwd_with_add``.

    Two fresh output buffers are allocated (shaped like ``x`` and like
    ``residual``) and handed to the kernel, which fills them.

    Returns:
        ``(normalized_output, updated_residual)``.
    """
    import aiter as rocm_aiter

    normalized = torch.empty_like(x)
    new_residual = torch.empty_like(residual)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normalized,  # output
        x,  # input
        residual,  # residual input
        new_residual,  # residual output
        weight,
        variance_epsilon,
    )
    return normalized, new_residual
94
95


96
97
98
def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Fake implementation registered for ``rocm_aiter_rms_norm``.

    Returns an uninitialized tensor created with ``empty_like(x)``, so it has
    the same shape, dtype and device as ``x``. ``weight`` and
    ``variance_epsilon`` are unused.
    """
    out_like_x = torch.empty_like(x)
    return out_like_x


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation registered for ``rocm_aiter_rmsnorm2d_fwd_with_add``.

    Returns a pair of uninitialized tensors created with ``empty_like``: one
    matching ``x`` (the normalized output) and one matching ``residual``
    (the updated residual). ``weight`` and ``variance_epsilon`` are unused.
    """
    fake_output = torch.empty_like(x)
    fake_residual = torch.empty_like(residual)
    return fake_output, fake_residual


if current_platform.is_rocm():
    # Register each AITER kernel as a vLLM custom op together with its fake
    # implementation (which only allocates empty_like outputs). The ops are
    # registered in this order:
    # rocm_aiter_rms_norm, then rocm_aiter_rmsnorm2d_fwd_with_add.
    for _op_name, _op_impl, _op_fake in (
        (
            "rocm_aiter_rms_norm",
            rocm_aiter_rms_norm_impl,
            rocm_aiter_rms_norm_fake,
        ),
        (
            "rocm_aiter_rmsnorm2d_fwd_with_add",
            rocm_aiter_rmsnorm2d_fwd_with_add_impl,
            rocm_aiter_rmsnorm2d_fwd_with_add_fake,
        ),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_impl,
            fake_impl=_op_fake,
        )
    # Do not leave the loop temporaries behind in the module namespace.
    del _op_name, _op_impl, _op_fake


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Select the RMSNorm callable to use on ROCm.

    The AITER custom ops are chosen only when AITER RMSNorm is enabled and
    ``dtype`` is float16 or bfloat16. Otherwise the module's ``rms_norm`` /
    ``fused_add_rms_norm`` helpers are returned.

    Args:
        with_fused_add: Return the residual-add variant when True.
        dtype: Dtype the norm will run with (checked against the AITER dtypes).
    """
    aiter_dtypes = (torch.float16, torch.bfloat16)
    if is_rocm_aiter_rmsnorm_enabled() and dtype in aiter_dtypes:
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
140
141


142
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If set and different from ``hidden_size``, the
                mean square is computed over only the first
                ``var_hidden_size`` features of the last dimension. All
                features are still normalized. This forces the native path.
            has_weight: If False, ``self.weight`` is a plain all-ones tensor
                rather than an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Store an override only when it actually differs from hidden_size;
        # the forward_* kernels fall back to the native path whenever it is set.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, based on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        The normalization is computed in float32 and cast back to
        ``orig_dtype`` before ``weight`` is applied. When ``residual`` is
        given, ``x + residual`` is normalized and that sum, cast to
        ``orig_dtype``, is returned as the new residual.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``hidden_size``,
                or if ``variance_size_override`` exceeds ``hidden_size``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last (feature) dimension for any input rank, e.g.
            # 2-D (tokens, hidden) as well as 3-D (batch, seq, hidden).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: custom kernels, or the native path if a variance
        size override is configured (the kernels do not take one)."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path: uses the kernels selected in ``__init__`` via
        ``dispatch_rocm_rmsnorm_func``, or the native path if a variance size
        override is configured."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path: IPEX kernels, or the native path if a variance size
        override is configured."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # x and residual are returned as-is after the call; presumably
            # the IPEX kernel updates them in place (verify).
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
302
303


304
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization variant used by Gemma.

    It differs from RMSNorm in two ways:
        1. The scale is (1 + w) rather than w.
        2. The scale is applied in float32 before casting back, i.e.
           (x * w).to(orig_dtype) rather than x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialized because the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        With ``residual``, the sum ``x + residual`` is normalized and also
        returned as the new residual. For float16 inputs the sum is formed in
        float32, and that float32 tensor is what is returned as the residual.
        """
        input_dtype = x.dtype
        if residual is not None:
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        normed = hidden.to(input_dtype)
        if residual is None:
            return normed
        return normed, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: the native computation, lazily torch.compile'd.

        Outside of an active compilation, the first call replaces
        ``forward_static`` on this instance with a ``torch.compile``'d
        version. Inside an active compilation, no nested compile is
        attempted.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
370
371


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias; registered as None so forward_cuda can pass self.bias.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the weight to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        The computation runs in float32 and the result is cast back to the
        dtype of ``x``.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor with the dtype of ``x``

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        orig_dtype = x.dtype
        # Upcast to float32, as the other norms in this module do. Computing
        # the mean square in half precision loses accuracy. The result is
        # cast back to the input dtype at the end; this is presumably what the
        # fused kernel in forward_cuda also does (verify against rmsnorm_fn).
        x = x.float()
        weight = self.weight.float()
        if z is not None:
            z = z.float()

        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm: split the last dim into groups of group_size
            # ("... (g d) -> ... g d"), normalize each group, then merge back.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = x_normed.flatten(-2) * weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """CUDA path: delegates to the ``rmsnorm_fn`` kernel wrapper."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


475
476
477
478
479
480
481
482
483
484
485
486
487
class LayerNorm(nn.Module):
    """
    Layer normalization over the last ``dim`` features, computed in float32.

    The weight (ones) and bias (zeros) are float32 parameters. The input is
    upcast to float32 for the computation, and the result is cast back to the
    input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.to(input_dtype)