"vllm/vscode:/vscode.git/clone" did not exist on "09540cd918a5f7d776d7f7e0abec78fbc03938ad"
layernorm.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
10
11
# Import kernels
import vllm.kernels  # noqa: F401
from vllm import _oink_ops, envs, ir
12
from vllm._aiter_ops import rocm_aiter_ops
13
from vllm.logger import init_logger
14
from vllm.model_executor.custom_op import CustomOp
15
16
17
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
)
18
19
from vllm.platforms import current_platform

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Module-level logger; used in this file for RMSNorm's Oink fast-path warnings.
logger = init_logger(__name__)


def _can_view_as_2d(x: torch.Tensor) -> bool:
    """Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
    if x.dim() < 2:
        return False
    if x.dim() == 2:
        return True
    # For a view(-1, N) to be valid, all leading dims must be contiguous with
    # respect to each other (size-1 dims are ignored).
    for dim in range(x.dim() - 1):
        # Strides for size-1 dims are irrelevant and can be arbitrary.
        if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
            dim + 1
        ):
            return False
    return True


def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
    """Return True if x_2d meets Oink's pointer-path stride constraints."""
    if x_2d.dim() != 2:
        return False
    if x_2d.stride(1) != 1:
        return False
    # Match Oink's vectorization constraint: stride(0) divisible by 256b.
    if x_2d.dtype in (torch.float16, torch.bfloat16):
        divby = 16
    elif x_2d.dtype == torch.float32:
        divby = 8
    else:
        return False
    return (x_2d.stride(0) % divby) == 0

55
56

def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``x`` to ``residual`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual tensor added to ``x``.
        weight: RMSNorm scale.
        variance_epsilon: Epsilon added to the mean of squares.

    Returns:
        ``(normed, new_residual)`` where ``new_residual = x + residual``.

    With ``VLLM_BATCH_INVARIANT`` enabled, the result is computed out of place
    via ``rms_norm_batch_invariant``. Otherwise ``_custom_ops.fused_add_rms_norm``
    is called and the same ``x`` / ``residual`` tensors are returned. The op's
    results are expected to have been written into them.
    """
    if envs.VLLM_BATCH_INVARIANT:
        # Compute the sum once: it is both the norm input and the new residual.
        # NOTE(review): the returned residual is the same tensor handed to
        # rms_norm_batch_invariant; assumes that function does not write into
        # its input — confirm.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


77
78
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the PolyNorm custom op on ``x`` and return its output.

    The output buffer is a new tensor allocated with ``torch.empty_like(x)``
    (same shape, dtype and device as ``x``). It is passed to
    ``_custom_ops.poly_norm`` as the first argument, followed by the input,
    weight, bias and epsilon.
    """
    from vllm import _custom_ops

    result = torch.empty_like(x)
    _custom_ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
91
92


93
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
    """Select the fused residual-add + RMSNorm implementation for ROCm.

    Returns ``rocm_aiter_ops.rms_norm2d_with_add`` only when AITER use was
    requested *and* ``dtype`` is fp16 or bf16. In every other case the generic
    ``fused_add_rms_norm`` is returned.
    """
    aiter_supported_dtypes = (torch.float16, torch.bfloat16)
    if not (use_aiter and dtype in aiter_supported_dtypes):
        return fused_add_rms_norm
    return rocm_aiter_ops.rms_norm2d_with_add
103
104


105
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create an RMSNorm layer.

        Args:
            hidden_size: Expected size of the last dimension of the input.
            eps: Added to the mean of squares before the reciprocal sqrt.
            var_hidden_size: If given and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                features of the last dimension (see ``forward_static``).
            has_weight: If False, ``weight`` is a plain all-ones tensor rather
                than a learnable ``nn.Parameter``.
            dtype: Dtype of ``weight``; defaults to
                ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

        # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
        # compatible CUDA devices (e.g., SM100) when the external Oink
        # package is available. This is detected once at construction time
        # to avoid per-call device queries in the hot path.
        self._use_oink_fused_add_rmsnorm = False
        if (
            not current_platform.is_rocm()
            and torch.cuda.is_available()
            and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
        ):
            # NOTE: vLLM disables custom ops by default when using Inductor.
            # If this op is disabled, CustomOp will dispatch to forward_native,
            # and the Oink path in forward_cuda will never run.
            if getattr(self._forward_method, "__func__", None) is getattr(
                self.forward_native, "__func__", None
            ):
                try:
                    from vllm.config import get_cached_compilation_config

                    custom_ops = get_cached_compilation_config().custom_ops
                except Exception:
                    custom_ops = ["<unknown>"]
                logger.warning_once(
                    "VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
                    "disabled (CompilationConfig.custom_ops=%s). Enable it via "
                    "`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
                    "(or `['all']`) to let vLLM call into torch.ops.oink.*.",
                    custom_ops,
                )
                # Custom op disabled => forward_cuda won't run. Avoid doing any
                # external Oink initialization work in this case.
            else:
                try:
                    device_index = torch.accelerator.current_device_index()
                    if _oink_ops.is_oink_available_for_device(device_index):
                        self._use_oink_fused_add_rmsnorm = (
                            _oink_ops.has_fused_add_rms_norm()
                        )
                except Exception as e:
                    # If anything goes wrong (no Oink install, CPU-only env,
                    # etc.), log a one-time warning and fall back to the
                    # built-in RMSNorm path.
                    logger.warning_once(
                        "VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
                        "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
                        e,
                    )
                    self._use_oink_fused_add_rmsnorm = False

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor whose last dimension must equal ``hidden_size``.
            variance_epsilon: Added to the mean of squares before rsqrt.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Dtype the normalized output (and new residual) is
                cast to.
            weight: Optional scale, applied after the cast to ``orig_dtype``.
            residual: Optional residual; if given, ``x + residual`` (in
                float32) is normalized and also returned as the new residual.
            variance_size_override: If given, only the first this-many
                features of the last dim contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size`` or
                ``hidden_size < variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last dim with `...` so any input rank works (e.g. the
            # 2-D [num_tokens, hidden] layout as well as 3-D inputs).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Without a residual this delegates to ``ir.ops.rms_norm``. With a
        residual it uses ``forward_static`` and returns
        ``(normalized, new_residual)``.
        """
        if residual is None:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA dispatch for RMSNorm.

        Dispatch order:
          1. No residual and not batch-invariant: ``ir.ops.rms_norm``.
          2. Variance-size override set: ``forward_native``.
          3. Residual given and the Oink fast path is enabled and applicable:
             in-place ``_oink_ops.fused_add_rms_norm_``.
          4. Residual given: ``fused_add_rms_norm``.
          5. Otherwise (batch-invariant, no residual):
             ``rms_norm_batch_invariant``.
        """
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
        # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
        # `x` (normalized output) and `residual` (residual-out buffer).
        if (
            residual is not None
            and getattr(self, "_use_oink_fused_add_rmsnorm", False)
            and x.is_cuda
            and residual.is_cuda
            and x.shape == residual.shape
            and x.dtype == residual.dtype
            and x.dim() >= 2
            and self.has_weight
            and not envs.VLLM_BATCH_INVARIANT
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            hidden_size = x.shape[-1]
            if _can_view_as_2d(x) and _can_view_as_2d(residual):
                x_2d = x.view(-1, hidden_size)
                res_2d = residual.view(-1, hidden_size)

                # The Oink in-place pointer path supports the common vLLM
                # layout where:
                # - `x` may be strided/padded row-major (stride(1) == 1), and
                # - `residual` is contiguous row-major ([M, N] with stride(0) == N).
                # If these conditions are not met, fall back to vLLM's built-in
                # fused kernel.
                if (
                    _is_oink_stride_compatible_2d(x_2d)
                    and _is_oink_stride_compatible_2d(res_2d)
                    and res_2d.is_contiguous()
                ):
                    _oink_ops.fused_add_rms_norm_(
                        x_2d,
                        res_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return x, residual

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm dispatch for RMSNorm.

        Same as ``forward_cuda`` except that the residual path uses the
        implementation chosen at construction time by
        ``dispatch_rocm_rmsnorm_func`` and there is no Oink fast path.
        """
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU dispatch; identical to ``forward_cuda``."""
        return self.forward_cuda(x, residual)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
352
353


354
# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization variant used by Gemma models.

    Compared with the plain RMSNorm layer in this module it differs in two ways:
        1. The scale is ``(1 + w)`` rather than ``w`` (``w`` is stored as an
           offset and initialized to zeros).
        2. The scale is applied in float32 *before* casting back, i.e.
           ``(x * (1 + w)).to(orig_dtype)`` instead of ``x.to(orig_dtype) * w``.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Offset from 1: the effective per-feature scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``x`` in float32, scale by ``1 + weight``, cast back."""
        input_dtype = x.dtype
        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + variance_epsilon)
        scale = 1.0 + weight.float()
        return (hidden * inv_rms * scale).to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Add ``residual`` to ``x``, normalize the sum, return (out, sum)."""
        input_dtype = x.dtype
        # fp16 inputs are summed in float32, and that float32 sum is what gets
        # returned as the new residual; other dtypes add natively.
        # NOTE(review): presumably to keep the fp16 residual stream from
        # overflowing — confirm.
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        hidden = summed.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + weight.float()
        return (hidden * inv_rms * scale).to(input_dtype), summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the native math, wrapping it in torch.compile on first eager use.

        When already being traced by torch.compile, the native code is used
        directly. Otherwise the two static helpers are replaced (once, per
        instance) by ``torch.compile``-wrapped versions.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
447
448


449
# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    Supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating by ``act(z)``, where ``act`` is selected by
      ``activation`` (SiLU for "swish"/"silu", sigmoid for "sigmoid")

    ``forward_native`` is a pure PyTorch implementation; ``forward_cuda`` /
    ``forward_xpu`` call the fused ``rmsnorm_fn`` kernel.
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * act(z)
                              If False and z is provided: out = norm(x * act(z))
            device: Device to create parameters on
            dtype: Data type for parameters
            activation: Gating activation name. ``forward_native`` supports
                "swish"/"silu" (SiLU) and "sigmoid"; the value is also passed
                to the fused kernel in ``forward_cuda``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def _gate_activation(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``act(z)`` for the configured ``self.activation``.

        Raises:
            ValueError: If ``self.activation`` is not supported natively.
        """
        if self.activation in ("swish", "silu"):
            return F.silu(z)
        if self.activation == "sigmoid":
            return torch.sigmoid(z)
        raise ValueError(
            f"Unsupported activation {self.activation!r} for RMSNormGated; "
            "expected one of 'swish', 'silu', 'sigmoid'."
        )

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor, in x's original dtype.

        If z is not None (``act`` is chosen by ``self.activation``):
            - norm_before_gate=True: out = norm(x) * act(z)
            - norm_before_gate=False: out = norm(x * act(z))

        Raises:
            ValueError: If z is given and ``self.activation`` is not one of
                "swish", "silu" or "sigmoid".
        """
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        # Gating factor act(z), computed in float32; None when ungated.
        gate = self._gate_activation(z.float()) if z is not None else None

        # Apply gating before normalization if needed
        if gate is not None and not self.norm_before_gate:
            x = x * gate

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * weight

        # Apply gating after normalization if needed
        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Fused-kernel implementation via fla ``layernorm_guard.rmsnorm_fn``."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )

    def forward_xpu(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """XPU dispatch; identical to ``forward_cuda``."""
        return self.forward_cuda(x, z)

568

569
570
571
572
573
574
575
576
577
578
579
580
581
class LayerNorm(nn.Module):
    """
    Layer Normalization.

    The normalization is always computed in float32 (input and affine
    parameters are upcast), and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer norm over the last ``dim`` features of ``x``."""
        # Upcast the parameters as well as the input. If the module itself was
        # cast (e.g. ``.to(torch.bfloat16)``), F.layer_norm would otherwise get
        # a float32 input with lower-precision weight/bias, a mismatch it may
        # reject. ``.float()`` is a no-op for parameters that are already fp32.
        return F.layer_norm(
            x.float(), (self.dim,), self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)