layernorm.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
10
11
# Import kernels
import vllm.kernels  # noqa: F401
from vllm import _oink_ops, envs, ir
12
from vllm._aiter_ops import rocm_aiter_ops
13
from vllm.logger import init_logger
14
from vllm.model_executor.custom_op import CustomOp
15
16
17
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
)
18
19
from vllm.platforms import current_platform

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
logger = init_logger(__name__)


def _can_view_as_2d(x: torch.Tensor) -> bool:
    """Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
    if x.dim() < 2:
        return False
    if x.dim() == 2:
        return True
    # For a view(-1, N) to be valid, all leading dims must be contiguous with
    # respect to each other (size-1 dims are ignored).
    for dim in range(x.dim() - 1):
        # Strides for size-1 dims are irrelevant and can be arbitrary.
        if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
            dim + 1
        ):
            return False
    return True


def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
    """Return True if x_2d meets Oink's pointer-path stride constraints."""
    if x_2d.dim() != 2:
        return False
    if x_2d.stride(1) != 1:
        return False
    # Match Oink's vectorization constraint: stride(0) divisible by 256b.
    if x_2d.dtype in (torch.float16, torch.bfloat16):
        divby = 16
    elif x_2d.dtype == torch.float32:
        divby = 8
    else:
        return False
    return (x_2d.stride(0) % divby) == 0

55
56

def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run vLLM's fused residual-add + RMSNorm custom kernel.

    The kernel's return value is not used. Instead, the same ``x`` and
    ``residual`` tensors that were passed in are handed back, so the kernel
    is expected to write its results into them in place: ``x`` receives the
    normalized output and ``residual`` the residual-out buffer.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


73
74
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply vLLM's PolyNorm custom kernel.

    A fresh output buffer with the same shape, dtype and device as ``x`` is
    allocated; the kernel fills it and that buffer is returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(  # type: ignore[attr-defined]
        result, x, weight, bias, variance_epsilon
    )
    return result
87
88


89
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
    """Select the fused residual-add + RMSNorm implementation used on ROCm.

    The AITER kernel is chosen only when ``use_aiter`` is requested and
    ``dtype`` is float16 or bfloat16. In every other case the generic
    ``fused_add_rms_norm`` wrapper is returned.
    """
    aiter_dtypes = (torch.float16, torch.bfloat16)
    if use_aiter and dtype in aiter_dtypes:
        return rocm_aiter_ops.rms_norm2d_with_add
    return fused_add_rms_norm
99
100


101
# --8<-- [start:rms_norm]
102
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create an RMSNorm layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the reciprocal sqrt.
            var_hidden_size: If set and different from ``hidden_size``, only
                the first ``var_hidden_size`` features of the last dimension
                are used to compute the variance (see ``forward_static``).
            has_weight: If False, ``self.weight`` stays a plain tensor of
                ones instead of being wrapped in an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "use the full hidden size for the variance".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

        # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
        # compatible CUDA devices (e.g., SM100) when the external Oink
        # package is available. This is detected once at construction time
        # to avoid per-call device queries in the hot path.
        self._use_oink_fused_add_rmsnorm = False
        if (
            not current_platform.is_rocm()
            and torch.cuda.is_available()
            and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
        ):
            # NOTE: vLLM disables custom ops by default when using Inductor.
            # If this op is disabled, CustomOp will dispatch to forward_native,
            # and the Oink path in forward_cuda will never run.
            if getattr(self._forward_method, "__func__", None) is getattr(
                self.forward_native, "__func__", None
            ):
                try:
                    from vllm.config import get_cached_compilation_config

                    custom_ops = get_cached_compilation_config().custom_ops
                except Exception:
                    custom_ops = ["<unknown>"]
                logger.warning_once(
                    "VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
                    "disabled (CompilationConfig.custom_ops=%s). Enable it via "
                    "`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
                    "(or `['all']`) to let vLLM call into torch.ops.oink.*.",
                    custom_ops,
                )
                # Custom op disabled => forward_cuda won't run. Avoid doing any
                # external Oink initialization work in this case.
            else:
                try:
                    device_index = torch.accelerator.current_device_index()
                    if _oink_ops.is_oink_available_for_device(device_index):
                        self._use_oink_fused_add_rmsnorm = (
                            _oink_ops.has_fused_add_rms_norm()
                        )
                except Exception as e:
                    # If anything goes wrong (no Oink install, CPU-only env,
                    # etc.), log a one-time warning and fall back to the
                    # built-in RMSNorm path.
                    logger.warning_once(
                        "VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
                        "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
                        e,
                    )
                    self._use_oink_fused_add_rmsnorm = False

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        The input is upcast to float32, normalized, cast back to
        ``orig_dtype`` and only then multiplied by ``weight`` (if given).
        When ``residual`` is given, ``x + residual`` is normalized and also
        returned, cast to ``orig_dtype``, as the new residual.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``,
                or ``variance_size_override`` is larger than ``hidden_size``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice along the last (hidden) dimension so inputs of any rank,
            # e.g. 2-D [num_tokens, hidden], are supported.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Without a residual this goes through ``ir.ops.rms_norm`` (passing
        ``None`` as the weight when ``has_weight`` is False); with a residual
        it uses ``forward_static``.
        """
        if residual is None:
            # TODO(luka): address the weight=None passing issue more generally
            return ir.ops.rms_norm(
                x,
                self.weight.data if self.has_weight else None,
                self.variance_epsilon,
                self.variance_size_override,
            )

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA implementation.

        Dispatch order:
          1. No residual, batch invariance off -> ``ir.ops.rms_norm``. Unlike
             ``forward_native`` this always passes ``self.weight.data``; it
             is a tensor of ones when ``has_weight`` is False.
          2. ``variance_size_override`` set -> ``forward_native``.
          3. Residual given and the Oink fast path is usable -> in-place
             ``_oink_ops.fused_add_rms_norm_``.
          4. Residual given -> vLLM's ``fused_add_rms_norm`` kernel.
          5. Otherwise (batch invariance on, no residual) ->
             ``rms_norm_batch_invariant``.

        In paths 3 and 4 the passed-in ``x`` and ``residual`` tensors are the
        ones returned, updated in place.
        """
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
        # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
        # `x` (normalized output) and `residual` (residual-out buffer).
        if (
            residual is not None
            and getattr(self, "_use_oink_fused_add_rmsnorm", False)
            and x.is_cuda
            and residual.is_cuda
            and x.shape == residual.shape
            and x.dtype == residual.dtype
            and x.dim() >= 2
            and self.has_weight
            and not envs.VLLM_BATCH_INVARIANT
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            orig_shape = x.shape
            hidden_size = orig_shape[-1]
            if _can_view_as_2d(x) and _can_view_as_2d(residual):
                x_2d = x.view(-1, hidden_size)
                res_2d = residual.view(-1, hidden_size)

                # The Oink in-place pointer path supports the common vLLM
                # layout where:
                # - `x` may be strided/padded row-major (stride(1) == 1), and
                # - `residual` is contiguous row-major ([M, N] with stride(0) == N).
                # If these conditions are not met, fall back to vLLM's built-in
                # fused kernel.
                if (
                    _is_oink_stride_compatible_2d(x_2d)
                    and _is_oink_stride_compatible_2d(res_2d)
                    and res_2d.is_contiguous()
                ):
                    _oink_ops.fused_add_rms_norm_(
                        x_2d,
                        res_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return x, residual

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm implementation.

        Mirrors ``forward_cuda`` without the Oink path. The residual case uses
        the function chosen in ``__init__`` by ``dispatch_rocm_rmsnorm_func``.
        """
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU implementation; reuses the CUDA dispatch logic."""
        return self.forward_cuda(x, residual)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
352
353


354
# --8<-- [start:gemma_rms_norm]
355
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization as used by Gemma models.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w, so the zero-initialized weight
           starts out as an identity scale.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Stored as an offset from 1; see the (1 + w) scale in forward_native.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Without a residual, returns the normalized ``x`` in its input dtype.
        With a residual, ``x + residual`` is formed first (in float32 when the
        input is float16, otherwise in the input dtype), normalized, and
        returned alongside that sum as the new residual. The sum is not cast
        back, so for float16 inputs the returned residual is float32.
        """
        input_dtype = x.dtype
        # Gemma scales by (1 + w); build that scale once in float32.
        scale = self.weight.data.float() + 1.0

        if residual is None:
            # ir.ops.rms_norm handles fp32 upcast internally
            normed = ir.ops.rms_norm(x, scale, self.variance_epsilon)
            return normed.to(input_dtype)

        if input_dtype == torch.float16:
            x = x.float() + residual.float()
        else:
            x = x + residual
        # ir.ops.rms_norm handles fp32 upcast internally
        normed = ir.ops.rms_norm(x, scale, self.variance_epsilon)
        return normed.to(input_dtype), x

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA implementation; identical to the native path."""
        return self.forward_native(x, residual)
402
403


404
# --8<-- [start:rms_norm_gated]
405
406
407
408
409
410
411
412
413
414
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with a SiLU (or sigmoid) activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Added to the mean square for numerical stability.
            group_size: If set, each contiguous run of ``group_size``
                features in the last dimension is normalized independently
                (GroupNorm-style). ``None`` is equivalent to
                ``group_size=hidden_size``, i.e. a single group.
            norm_before_gate: Only matters when a gate ``z`` is passed.
                If True: ``out = norm(x) * act(z)``;
                if False: ``out = norm(x * act(z))``.
            device: Device for the weight parameter.
            dtype: Dtype for the weight parameter.
            activation: Gate activation name. In ``forward_native``,
                ``"silu"`` and ``"swish"`` select SiLU and ``"sigmoid"``
                selects sigmoid.
        """
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(
            torch.empty(hidden_size, device=device, dtype=dtype)
        )
        # No bias; registered as None so `self.bias` exists for forward_cuda.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the weight to ones."""
        nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        All math runs in float32; the result is cast back to ``x.dtype``.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * act(z)
            - norm_before_gate=False: out = norm(x * act(z))
        """
        out_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        gate = None if z is None else z.float()

        assert self.activation in ["silu", "sigmoid", "swish"]
        gate_fn = F.silu if self.activation != "sigmoid" else F.sigmoid

        # Gate the input first when normalization comes after gating.
        if gate is not None and not self.norm_before_gate:
            x = x * gate_fn(gate)

        if self.group_size is None:
            # One RMS statistic over the whole last dimension.
            mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
            out = x * torch.rsqrt(mean_sq + self.eps) * weight
        else:
            # One RMS statistic per group of `group_size` features.
            from einops import rearrange

            grouped = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            mean_sq = grouped.pow(2).mean(dim=-1, keepdim=True)
            grouped = grouped * torch.rsqrt(mean_sq + self.eps)
            out = rearrange(grouped, "... g d -> ... (g d)") * weight

        # Gate the output when normalization comes before gating.
        if gate is not None and self.norm_before_gate:
            out = out * gate_fn(gate)

        return out.to(out_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to ``rmsnorm_fn`` from the FLA layernorm_guard module."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )

    def forward_xpu(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """XPU implementation; reuses the CUDA path."""
        return self.forward_cuda(x, z)

526

527
528
529
530
531
532
533
534
535
536
537
538
539
class LayerNorm(nn.Module):
    """
    Layer normalization evaluated in float32.

    The affine ``weight``/``bias`` are float32 parameters initialized to ones
    and zeros. The input is upcast to float32, normalized over its last
    ``dim`` features, and the result is cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)