"vllm/vscode:/vscode.git/clone" did not exist on "59488cc9b14e87d184dff4533348d83c45a49d02"
layernorm.py 20.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm import _oink_ops, envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.logger import init_logger
12
from vllm.model_executor.custom_op import CustomOp
13
14
15
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
)
16
17
from vllm.platforms import current_platform

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Module-level logger; RMSNorm uses it to report Oink fast-path status.
logger = init_logger(__name__)


def _can_view_as_2d(x: torch.Tensor) -> bool:
    """Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
    if x.dim() < 2:
        return False
    if x.dim() == 2:
        return True
    # For a view(-1, N) to be valid, all leading dims must be contiguous with
    # respect to each other (size-1 dims are ignored).
    for dim in range(x.dim() - 1):
        # Strides for size-1 dims are irrelevant and can be arbitrary.
        if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
            dim + 1
        ):
            return False
    return True


def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
    """Return True if x_2d meets Oink's pointer-path stride constraints."""
    if x_2d.dim() != 2:
        return False
    if x_2d.stride(1) != 1:
        return False
    # Match Oink's vectorization constraint: stride(0) divisible by 256b.
    if x_2d.dtype in (torch.float16, torch.bfloat16):
        divby = 16
    elif x_2d.dtype == torch.float32:
        divby = 8
    else:
        return False
    return (x_2d.stride(0) % divby) == 0

53

54
55
56
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` and return the result.

    When ``envs.VLLM_BATCH_INVARIANT`` is set the batch-invariant
    implementation is used; otherwise the result is written by
    ``ops.rms_norm`` into a freshly allocated tensor shaped like ``x``.

    Args:
        x: Input tensor.
        weight: Scale passed through to the kernel.
        variance_epsilon: Epsilon passed through to the kernel.

    Returns:
        The normalized tensor.
    """
    # Imported lazily inside the function, as in the rest of this module.
    from vllm import _custom_ops as ops

    if envs.VLLM_BATCH_INVARIANT:
        return rms_norm_batch_invariant(x, weight, variance_epsilon)
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input tensor.
        residual: Residual tensor added to ``x`` before normalization.
        weight: Scale passed through to the kernel.
        variance_epsilon: Epsilon passed through to the kernel.

    Returns:
        A ``(normalized, residual_out)`` pair. In batch-invariant mode both
        are new tensors, with ``residual_out`` equal to ``x + residual``. In
        the default mode ``ops.fused_add_rms_norm`` is invoked on ``x`` and
        ``residual`` and those same two tensor objects are returned.
    """
    from vllm import _custom_ops as ops

    if envs.VLLM_BATCH_INVARIANT:
        # Compute the residual sum once and reuse it for both outputs.
        # NOTE(review): assumes rms_norm_batch_invariant does not modify its
        # input in place (rms_norm passes the caller's tensor straight to it).
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


92
93
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel on ``x``.

    Args:
        x: Input tensor.
        weight: Weight tensor passed through to the kernel.
        bias: Bias tensor passed through to the kernel.
        variance_epsilon: Epsilon passed through to the kernel.

    Returns:
        A newly allocated tensor shaped like ``x`` that the kernel was given
        as its output buffer.
    """
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    ops.poly_norm(
        out,
        x,
        weight,
        bias,
        variance_epsilon,
    )
    return out
106
107


108
109
110
111
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable to use on ROCm.

    Args:
        with_fused_add: If True, return a fused residual-add + norm callable;
            otherwise a plain norm callable.
        dtype: Weight dtype; the AITER kernels are only selected for
            float16 and bfloat16.
        use_aiter: Whether the AITER RMSNorm path is enabled.

    Returns:
        One of ``rocm_aiter_ops.rms_norm2d_with_add``,
        ``rocm_aiter_ops.rms_norm``, ``fused_add_rms_norm`` or ``rms_norm``.
    """
    use_aiter = use_aiter and dtype in [
        torch.float16,
        torch.bfloat16,
    ]

    if use_aiter and with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    if use_aiter:
        return rocm_aiter_ops.rms_norm

    # fall back to CUDA implementation
    if with_fused_add:
        return fused_add_rms_norm
    return rms_norm
125
126


127
# --8<-- [start:rms_norm]
128
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer and pick platform-specific kernels.

        Args:
            hidden_size: Size of the last (normalized) dimension.
            eps: Epsilon added to the variance.
            var_hidden_size: If given and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                features (see ``forward_static``).
            has_weight: If True the weight is an ``nn.Parameter``; otherwise
                it stays a plain tensor of ones.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "use the full hidden size for the variance".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once, at construction time.
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

        # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
        # compatible CUDA devices (e.g., SM100) when the external Oink
        # package is available. This is detected once at construction time
        # to avoid per-call device queries in the hot path.
        self._use_oink_rmsnorm = False
        self._use_oink_fused_add_rmsnorm = False
        if (
            not current_platform.is_rocm()
            and torch.cuda.is_available()
            and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
        ):
            # NOTE: vLLM disables custom ops by default when using Inductor.
            # If this op is disabled, CustomOp will dispatch to forward_native,
            # and the Oink path in forward_cuda will never run.
            if getattr(self._forward_method, "__func__", None) is getattr(
                self.forward_native, "__func__", None
            ):
                try:
                    from vllm.config import get_cached_compilation_config

                    custom_ops = get_cached_compilation_config().custom_ops
                except Exception:
                    # Only used to enrich the warning below; best effort.
                    custom_ops = ["<unknown>"]
                logger.warning_once(
                    "VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
                    "disabled (CompilationConfig.custom_ops=%s). Enable it via "
                    "`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
                    "(or `['all']`) to let vLLM call into torch.ops.oink.*.",
                    custom_ops,
                )
                # Custom op disabled => forward_cuda won't run. Avoid doing any
                # external Oink initialization work in this case.
            else:
                try:
                    device_index = torch.accelerator.current_device_index()
                    if _oink_ops.is_oink_available_for_device(device_index):
                        self._use_oink_rmsnorm = True
                        self._use_oink_fused_add_rmsnorm = (
                            _oink_ops.has_fused_add_rms_norm()
                        )
                except Exception as e:
                    # If anything goes wrong (no Oink install, CPU-only env, etc.),
                    # log a warning once and fall back to the built-in RMSNorm path.
                    logger.warning_once(
                        "VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
                        "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
                        e,
                    )
                    self._use_oink_rmsnorm = False
                    self._use_oink_fused_add_rmsnorm = False

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor; its last dim must equal ``hidden_size``.
            variance_epsilon: Epsilon added to the variance.
            hidden_size: Expected size of the last dim of ``x``.
            orig_dtype: Dtype the output (and residual) is cast back to.
            weight: Optional scale applied after the cast to ``orig_dtype``.
            residual: Optional tensor added to ``x`` before normalization.
            variance_size_override: If set, the variance is computed over only
                the first ``variance_size_override`` features of the last dim.

        Returns:
            The normalized tensor, or ``(normalized, residual_out)`` when
            ``residual`` is given.

        Raises:
            ValueError: If the last dim of ``x`` differs from ``hidden_size``
                or ``hidden_size`` is smaller than ``variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dim so inputs of any rank (e.g. 2D
            # [tokens, hidden] as well as 3D) are supported.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA implementation: Oink fast paths first, then vLLM's kernels.

        Falls back to ``forward_native`` whenever a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Optional Oink SM100 fast path (no residual). This path is
        # torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
        # 2D layouts (including padded rows) when using the Oink
        # pointer-based kernel.
        if (
            residual is None
            and getattr(self, "_use_oink_rmsnorm", False)
            and x.is_cuda
            and x.dim() >= 2
            and self.has_weight
            and not envs.VLLM_BATCH_INVARIANT
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            orig_shape = x.shape
            hidden_size = orig_shape[-1]
            if _can_view_as_2d(x):
                x_2d = x.view(-1, hidden_size)
                if _is_oink_stride_compatible_2d(x_2d):
                    y_2d = _oink_ops.rmsnorm(
                        x_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return y_2d.view(orig_shape)

        # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
        # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
        # `x` (normalized output) and `residual` (residual-out buffer).
        if (
            residual is not None
            and getattr(self, "_use_oink_fused_add_rmsnorm", False)
            and x.is_cuda
            and residual.is_cuda
            and x.shape == residual.shape
            and x.dtype == residual.dtype
            and x.dim() >= 2
            and self.has_weight
            and not envs.VLLM_BATCH_INVARIANT
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            orig_shape = x.shape
            hidden_size = orig_shape[-1]
            if _can_view_as_2d(x) and _can_view_as_2d(residual):
                x_2d = x.view(-1, hidden_size)
                res_2d = residual.view(-1, hidden_size)

                # The Oink in-place pointer path supports the common vLLM
                # layout where:
                # - `x` may be strided/padded row-major (stride(1) == 1), and
                # - `residual` is contiguous row-major ([M, N] with stride(0) == N).
                # If these conditions are not met, fall back to vLLM's built-in
                # fused kernel.
                if (
                    _is_oink_stride_compatible_2d(x_2d)
                    and _is_oink_stride_compatible_2d(res_2d)
                    and res_2d.is_contiguous()
                ):
                    _oink_ops.fused_add_rms_norm_(
                        x_2d,
                        res_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return x, residual

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm implementation using the callables chosen in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU implementation; delegates to ``forward_cuda``."""
        return self.forward_cuda(x, residual)

    def extra_repr(self) -> str:
        """Return the ``hidden_size``/``eps`` summary shown in the module repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
394
395


396
# --8<-- [start:gemma_rms_norm]
397
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer with a zero-initialized weight of ``hidden_size``.

        Args:
            hidden_size: Number of elements in the weight.
            eps: Epsilon added to the variance.
        """
        super().__init__()
        # Zero init: the effective scale applied in forward is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Scale in float32 first, then cast back (Gemma ordering).
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual."""
        orig_dtype = x.dtype
        # For float16 the sum is formed in float32; other dtypes add directly.
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        # NOTE(review): for float16 inputs the returned residual is this
        # float32 sum (it is not cast back to orig_dtype) - confirm callers
        # expect that.
        residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is None:
            return self._forward_static_no_residual(
                self.weight.data, self.variance_epsilon, x
            )
        else:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA implementation: the native math, wrapped in ``torch.compile``.

        While already being compiled the native path is called directly. On
        the first eager call the two static helpers are replaced on this
        instance by their ``torch.compile``-wrapped versions.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            # Instance attributes shadow the class-level static methods, so
            # forward_native picks up the compiled versions from now on.
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
489
490


491
# --8<-- [start:rms_norm_gated]
492
493
494
495
496
497
498
499
500
501
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
            activation: Activation function name for gating
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias is used; registered as None so `self.bias` still exists.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        # Compute in float32 and cast back to the input dtype at the end.
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        z = z.float() if z is not None else None

        # NOTE(review): this native path always gates with SiLU and does not
        # consult self.activation, whereas forward_cuda forwards
        # self.activation to rmsnorm_fn - confirm both agree for non-default
        # activations.

        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """CUDA implementation; delegates to the FLA ``rmsnorm_fn`` op."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )


606
607
608
609
610
611
612
613
614
615
616
617
618
class LayerNorm(nn.Module):
    """
    Layer Normalization.

    The weight and bias are kept in float32; the input is upcast to float32
    for the computation and the result is cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer.

        Args:
            dim: Size of the last (normalized) dimension.
            eps: Epsilon passed to ``F.layer_norm``.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its last dim and return it in ``x``'s type."""
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)