layernorm.py 20.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm import _oink_ops, envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.logger import init_logger
12
from vllm.model_executor.custom_op import CustomOp
13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
18
from vllm.platforms import current_platform

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Module-level logger; RMSNorm uses it for one-time warnings about the
# optional Oink fast path.
logger = init_logger(__name__)


def _can_view_as_2d(x: torch.Tensor) -> bool:
    """Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
    if x.dim() < 2:
        return False
    if x.dim() == 2:
        return True
    # For a view(-1, N) to be valid, all leading dims must be contiguous with
    # respect to each other (size-1 dims are ignored).
    for dim in range(x.dim() - 1):
        # Strides for size-1 dims are irrelevant and can be arbitrary.
        if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
            dim + 1
        ):
            return False
    return True


def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
    """Return True if x_2d meets Oink's pointer-path stride constraints."""
    if x_2d.dim() != 2:
        return False
    if x_2d.stride(1) != 1:
        return False
    # Match Oink's vectorization constraint: stride(0) divisible by 256b.
    if x_2d.dtype in (torch.float16, torch.bfloat16):
        divby = 16
    elif x_2d.dtype == torch.float32:
        divby = 8
    else:
        return False
    return (x_2d.stride(0) % divby) == 0

54

55
56
57
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Return RMSNorm(x) * weight as a new tensor.

    When batch invariance is enabled, delegates to
    ``rms_norm_batch_invariant``. Otherwise, allocates an output like ``x`` and
    fills it with the ``ops.rms_norm`` custom kernel.
    """
    from vllm import _custom_ops as ops

    if not vllm_is_batch_invariant():
        result = torch.empty_like(x)
        ops.rms_norm(result, x, weight, variance_epsilon)
        return result
    return rms_norm_batch_invariant(x, weight, variance_epsilon)


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute RMSNorm(x + residual) and return it with the summed residual.

    Returns:
        ``(normalized, new_residual)``. On the custom-kernel path these are the
        very ``x`` and ``residual`` objects passed in (the kernel writes into
        them); on the batch-invariant path they are freshly computed tensors.
    """
    from vllm import _custom_ops as ops

    if not vllm_is_batch_invariant():
        # In-place kernel: `x` receives the normalized output and `residual`
        # the residual sum.
        ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
        return x, residual
    normalized = rms_norm_batch_invariant(x + residual, weight, variance_epsilon)
    return normalized, x + residual


93
94
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel and return its output.

    The output is a fresh tensor allocated like ``x``; ``x`` itself is passed
    to the kernel only as the input.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


109
110
111
112
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable used on ROCm.

    AITER kernels are chosen only when ``use_aiter`` is set and ``dtype`` is
    fp16 or bf16. In every other case, this returns the module-level
    ``fused_add_rms_norm`` / ``rms_norm`` (the CUDA-style custom-op path).

    Args:
        with_fused_add: Select the residual-add + norm variant.
        dtype: Weight dtype of the norm layer.
        use_aiter: Whether AITER RMSNorm is enabled.
    """
    if use_aiter and dtype in (torch.float16, torch.bfloat16):
        if with_fused_add:
            return rocm_aiter_ops.rms_norm2d_with_add
        return rocm_aiter_ops.rms_norm
    # Fall back to the CUDA implementation.
    return fused_add_rms_norm if with_fused_add else rms_norm


128
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create an RMSNorm layer.

        Args:
            hidden_size: Expected size of the last dimension of the input.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If set and different from ``hidden_size``, the
                variance is computed only over the first ``var_hidden_size``
                channels of the last dimension (served by the native path).
            has_weight: If False, ``self.weight`` is a plain tensor of ones
                instead of a learnable ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

        # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
        # compatible CUDA devices (e.g., SM100) when the external Oink
        # package is available. This is detected once at construction time
        # to avoid per-call device queries in the hot path.
        self._use_oink_rmsnorm = False
        self._use_oink_fused_add_rmsnorm = False
        if (
            not current_platform.is_rocm()
            and torch.cuda.is_available()
            and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
        ):
            # NOTE: vLLM disables custom ops by default when using Inductor.
            # If this op is disabled, CustomOp will dispatch to forward_native,
            # and the Oink path in forward_cuda will never run.
            if getattr(self._forward_method, "__func__", None) is getattr(
                self.forward_native, "__func__", None
            ):
                try:
                    from vllm.config import get_cached_compilation_config

                    custom_ops = get_cached_compilation_config().custom_ops
                except Exception:
                    custom_ops = ["<unknown>"]
                logger.warning_once(
                    "VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
                    "disabled (CompilationConfig.custom_ops=%s). Enable it via "
                    "`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
                    "(or `['all']`) to let vLLM call into torch.ops.oink.*.",
                    custom_ops,
                )
                # Custom op disabled => forward_cuda won't run. Avoid doing any
                # external Oink initialization work in this case.
            else:
                try:
                    device_index = torch.accelerator.current_device_index()
                    if _oink_ops.is_oink_available_for_device(device_index):
                        self._use_oink_rmsnorm = True
                        self._use_oink_fused_add_rmsnorm = (
                            _oink_ops.has_fused_add_rms_norm()
                        )
                except Exception as e:
                    # If anything goes wrong (no Oink install, CPU-only env, etc.),
                    # fall back to the built-in RMSNorm path (with a warning).
                    logger.warning_once(
                        "VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
                        "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
                        e,
                    )
                    self._use_oink_rmsnorm = False
                    self._use_oink_fused_add_rmsnorm = False

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Dtype the normalized output and new residual are cast to.
            weight: Optional per-channel scale, applied after the cast.
            residual: If given, ``x + residual`` is normalized and the sum
                (cast to ``orig_dtype``) is also returned as the new residual.
            variance_size_override: If given, the variance is computed only
                over the first ``variance_size_override`` channels of the last
                dimension; all channels are still normalized.

        Returns:
            The normalized tensor, or ``(normalized, residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If the last dim of ``x`` is not ``hidden_size``, or
                ``hidden_size`` is smaller than ``variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last dim with `...` so this works for any input rank
            # (2D as well as 3D); `x[:, :, :k]` only worked for 3D inputs.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: optional Oink kernels, else vLLM's custom RMSNorm ops.

        Falls back to ``forward_native`` when a variance-size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Optional Oink SM100 fast path (no residual). This path is
        # torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
        # 2D layouts (including padded rows) when using the Oink
        # pointer-based kernel.
        if (
            residual is None
            and getattr(self, "_use_oink_rmsnorm", False)
            and x.is_cuda
            and x.dim() >= 2
            and self.has_weight
            and not vllm_is_batch_invariant()
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            orig_shape = x.shape
            hidden_size = orig_shape[-1]
            if _can_view_as_2d(x):
                x_2d = x.view(-1, hidden_size)
                if _is_oink_stride_compatible_2d(x_2d):
                    y_2d = _oink_ops.rmsnorm(
                        x_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return y_2d.view(orig_shape)

        # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
        # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
        # `x` (normalized output) and `residual` (residual-out buffer).
        if (
            residual is not None
            and getattr(self, "_use_oink_fused_add_rmsnorm", False)
            and x.is_cuda
            and residual.is_cuda
            and x.shape == residual.shape
            and x.dtype == residual.dtype
            and x.dim() >= 2
            and self.has_weight
            and not vllm_is_batch_invariant()
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            orig_shape = x.shape
            hidden_size = orig_shape[-1]
            if _can_view_as_2d(x) and _can_view_as_2d(residual):
                x_2d = x.view(-1, hidden_size)
                res_2d = residual.view(-1, hidden_size)

                # The Oink in-place pointer path supports the common vLLM
                # layout where:
                # - `x` may be strided/padded row-major (stride(1) == 1), and
                # - `residual` is contiguous row-major ([M, N] with stride(0) == N).
                # If these conditions are not met, fall back to vLLM's built-in
                # fused kernel.
                if (
                    _is_oink_stride_compatible_2d(x_2d)
                    and _is_oink_stride_compatible_2d(res_2d)
                    and res_2d.is_contiguous()
                ):
                    _oink_ops.fused_add_rms_norm_(
                        x_2d,
                        res_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return x, residual

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the callables selected in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path; reuses ``forward_cuda``."""
        return self.forward_cuda(x, residual)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


397
# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """Gemma-flavoured RMS normalization.

    Compared with RMSNorm above:
        1. the scale is (1 + w) rather than w, and
        2. scaling happens in float32 before the cast back, i.e.
           (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialised, since the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``x`` in float32, scale by (1 + weight), cast back."""
        input_dtype = x.dtype
        hidden = x.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        return (hidden * inv_rms * (1.0 + weight.float())).to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x + residual``; return ``(normalized, summed_residual)``.

        For float16 inputs the sum is done in float32 and the returned residual
        stays float32. NOTE(review): presumably intentional (fp16 range);
        confirm before changing.
        """
        input_dtype = x.dtype
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        hidden = summed.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normalized = (hidden * inv_rms * (1.0 + weight.float())).to(input_dtype)
        return normalized, summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the native math, torch.compile-ing it on the first eager call.

        The compiled functions are stored as instance attributes that shadow
        the staticmethods, so ``forward_native`` picks them up. Compilation is
        skipped while an outer torch.compile trace is in progress.
        """
        if not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        ):
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)


492
# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with a configurable activation (SiLU/swish by default)
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * act(z)
                              If False and z is provided: out = norm(x * act(z))
            device: Device to create parameters on
            dtype: Data type for parameters
            activation: Activation function name for gating ("swish"/"silu"
                        or "sigmoid" in the native path)
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias; registered as None so forward_cuda can pass `self.bias`.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def _activate_gate(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the configured gating activation to ``z``.

        NOTE(review): keep the accepted names in sync with the activations
        handled by ``rmsnorm_fn`` (used in ``forward_cuda``).

        Raises:
            ValueError: If ``self.activation`` is not supported here.
        """
        if self.activation in ("swish", "silu"):
            return F.silu(z)
        if self.activation == "sigmoid":
            return torch.sigmoid(z)
        raise ValueError(
            f"Unsupported gating activation for RMSNormGated: {self.activation!r}"
        )

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor, cast back to x's dtype.

        If z is not None (act is the configured activation, SiLU by default):
            - norm_before_gate=True: out = norm(x) * act(z)
            - norm_before_gate=False: out = norm(x * act(z))
        """
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        # Compute the gate once in float32; it is applied exactly once below.
        gate = None if z is None else self._activate_gate(z.float())

        # Apply gating before normalization if needed
        if gate is not None and not self.norm_before_gate:
            x = x * gate

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm: split the last dim into (num_groups, group_size),
            # normalize each group, then merge back.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = x_normed.flatten(-2) * weight

        # Apply gating after normalization if needed
        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """CUDA path using the ``rmsnorm_fn`` kernel from ``layernorm_guard``."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )


607
608
609
610
611
612
613
614
615
616
617
618
619
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension.

    The affine weight and bias are float32 parameters. The input is upcast to
    float32 for the computation, and the result is cast back to the input's
    dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.to(dtype=x.dtype)