"vscode:/vscode.git/clone" did not exist on "e02706d2d27c9af429adf89e7dec2b37e3ec39c1"
layernorm.py 18.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
10
11
# Import kernels
import vllm.kernels  # noqa: F401
from vllm import _oink_ops, envs, ir
12
from vllm._aiter_ops import rocm_aiter_ops
13
from vllm.logger import init_logger
14
from vllm.model_executor.custom_op import CustomOp
15
16
17
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
)
18
19
from vllm.platforms import current_platform

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Module-level logger; used below for the Oink-related warnings emitted while
# constructing RMSNorm.
logger = init_logger(__name__)


def _can_view_as_2d(x: torch.Tensor) -> bool:
    """Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
    if x.dim() < 2:
        return False
    if x.dim() == 2:
        return True
    # For a view(-1, N) to be valid, all leading dims must be contiguous with
    # respect to each other (size-1 dims are ignored).
    for dim in range(x.dim() - 1):
        # Strides for size-1 dims are irrelevant and can be arbitrary.
        if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
            dim + 1
        ):
            return False
    return True


def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
    """Return True if x_2d meets Oink's pointer-path stride constraints."""
    if x_2d.dim() != 2:
        return False
    if x_2d.stride(1) != 1:
        return False
    # Match Oink's vectorization constraint: stride(0) divisible by 256b.
    if x_2d.dtype in (torch.float16, torch.bfloat16):
        divby = 16
    elif x_2d.dtype == torch.float32:
        divby = 8
    else:
        return False
    return (x_2d.stride(0) % divby) == 0

55
56

def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Residual add followed by RMSNorm.

    Returns ``(rms_norm(x + residual), x + residual)``.

    With ``VLLM_BATCH_INVARIANT`` set, the result is computed out-of-place via
    ``rms_norm_batch_invariant``. Otherwise ``ops.fused_add_rms_norm`` is
    invoked on ``x`` and ``residual`` and those same tensor objects are
    returned. Presumably the kernel writes its outputs into them in place;
    confirm against the op definition.
    """
    from vllm import _custom_ops as ops

    if envs.VLLM_BATCH_INVARIANT:
        # Form the residual sum once and reuse it for both outputs instead of
        # launching the elementwise add (and allocating its result) twice.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


77
78
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run vLLM's ``poly_norm`` custom op on ``x``.

    A fresh output buffer shaped/typed like ``x`` is allocated. It is filled
    by ``ops.poly_norm(out, x, weight, bias, variance_epsilon)`` and returned;
    ``x`` itself is only passed as the input argument.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
91
92


93
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
    """Choose the fused residual-add + RMSNorm callable for ROCm.

    The AITER kernel ``rocm_aiter_ops.rms_norm2d_with_add`` is selected only
    when ``use_aiter`` is truthy and ``dtype`` is fp16 or bf16. Every other
    combination falls back to the module's ``fused_add_rms_norm`` wrapper.
    """
    if not use_aiter or dtype not in (torch.float16, torch.bfloat16):
        return fused_add_rms_norm
    return rocm_aiter_ops.rms_norm2d_with_add
103
104


105
# --8<-- [start:rms_norm]
106
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the variance before the reciprocal square root.
            var_hidden_size: If given and different from ``hidden_size``, it
                is stored as ``variance_size_override``. The variance is then
                taken over only that many leading features of the last dim.
            has_weight: If False, ``self.weight`` is a plain all-ones tensor
                rather than an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

        # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
        # compatible CUDA devices (e.g., SM100) when the external Oink
        # package is available. This is detected once at construction time
        # to avoid per-call device queries in the hot path.
        self._use_oink_fused_add_rmsnorm = False
        if (
            not current_platform.is_rocm()
            and torch.cuda.is_available()
            and bool(getattr(envs, "VLLM_USE_OINK_OPS", False))
        ):
            # NOTE: vLLM disables custom ops by default when using Inductor.
            # If this op is disabled, CustomOp will dispatch to forward_native,
            # and the Oink path in forward_cuda will never run.
            if getattr(self._forward_method, "__func__", None) is getattr(
                self.forward_native, "__func__", None
            ):
                try:
                    from vllm.config import get_cached_compilation_config

                    custom_ops = get_cached_compilation_config().custom_ops
                except Exception:
                    custom_ops = ["<unknown>"]
                logger.warning_once(
                    "VLLM_USE_OINK_OPS=1 but the `rms_norm` custom op is "
                    "disabled (CompilationConfig.custom_ops=%s). Enable it via "
                    "`compilation_config={'custom_ops': ['none', '+rms_norm']}` "
                    "(or `['all']`) to let vLLM call into torch.ops.oink.*.",
                    custom_ops,
                )
                # Custom op disabled => forward_cuda won't run. Avoid doing any
                # external Oink initialization work in this case.
            else:
                try:
                    device_index = torch.accelerator.current_device_index()
                    if _oink_ops.is_oink_available_for_device(device_index):
                        self._use_oink_fused_add_rmsnorm = (
                            _oink_ops.has_fused_add_rms_norm()
                        )
                except Exception as e:
                    # If anything goes wrong (no Oink install, CPU-only env, etc.),
                    # fall back to the built-in RMSNorm path (with a one-time
                    # warning).
                    logger.warning_once(
                        "VLLM_USE_OINK_OPS=1 but failed to initialize Oink "
                        "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
                        e,
                    )
                    self._use_oink_fused_add_rmsnorm = False

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        The math runs in float32. When ``residual`` is given, ``x + residual``
        is normalized. The sum, cast to ``orig_dtype``, is also returned as
        the new residual. When ``variance_size_override`` is set, the variance
        uses only the first ``variance_size_override`` features of the last
        dimension, whatever the number of leading dimensions.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when a
            residual was supplied.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size``, or if
                ``hidden_size < variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last dim via `...` so any input rank works; a fixed
            # `x[:, :, :k]` raises IndexError on 2-D [tokens, hidden] inputs.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Without a residual this defers to ``ir.ops.rms_norm``; with one it
        uses ``forward_static``.
        """
        if residual is None:
            # TODO(luka): address the weight=None passing issue more generally
            return ir.ops.rms_norm(
                x,
                self.weight.data if self.has_weight else None,
                self.variance_epsilon,
                self.variance_size_override,
            )

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: IR op, optional Oink fast path, or vLLM fused kernel."""
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        # The fused kernels below do not support a partial-variance override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
        # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
        # `x` (normalized output) and `residual` (residual-out buffer).
        if (
            residual is not None
            and getattr(self, "_use_oink_fused_add_rmsnorm", False)
            and x.is_cuda
            and residual.is_cuda
            and x.shape == residual.shape
            and x.dtype == residual.dtype
            and x.dim() >= 2
            and self.has_weight
            and not envs.VLLM_BATCH_INVARIANT
            and self.weight.data.dtype == x.dtype
            and self.weight.data.is_contiguous()
        ):
            hidden_size = x.shape[-1]
            if _can_view_as_2d(x) and _can_view_as_2d(residual):
                x_2d = x.view(-1, hidden_size)
                res_2d = residual.view(-1, hidden_size)

                # The Oink in-place pointer path supports the common vLLM
                # layout where:
                # - `x` may be strided/padded row-major (stride(1) == 1), and
                # - `residual` is contiguous row-major ([M, N] with stride(0) == N).
                # If these conditions are not met, fall back to vLLM's built-in
                # fused kernel.
                if (
                    _is_oink_stride_compatible_2d(x_2d)
                    and _is_oink_stride_compatible_2d(res_2d)
                    and res_2d.is_contiguous()
                ):
                    _oink_ops.fused_add_rms_norm_(
                        x_2d,
                        res_2d,
                        self.weight.data,
                        self.variance_epsilon,
                    )
                    return x, residual

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path: IR op, or the residual-add function chosen in __init__."""
        if residual is None and not envs.VLLM_BATCH_INVARIANT:
            return ir.ops.rms_norm(
                x, self.weight.data, self.variance_epsilon, self.variance_size_override
            )

        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            assert envs.VLLM_BATCH_INVARIANT
            return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path: identical to forward_cuda."""
        return self.forward_cuda(x, residual)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
356
357


358
# --8<-- [start:gemma_rms_norm]
359
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """Gemma-flavoured RMS normalization.

    It differs from the plain RMSNorm in this file in two ways:
        1. The stored weight is an offset: the scale applied is (1 + w), so
           the weight is zero-initialized.
        2. The scaled product is cast back, i.e. (x * w).to(orig_dtype), rather
           than casting x first and then scaling.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero-initialised offset; the effective scale is 1 + weight.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        With a residual, the sum ``x + residual`` is normalized and returned
        as the new residual. For fp16 inputs that sum is formed in float32 and
        returned in float32; otherwise it stays in the input dtype. The
        normalized output is always cast back to the input dtype.
        """
        input_dtype = x.dtype
        scale = self.weight.data.float() + 1.0

        if residual is None:
            # ir.ops.rms_norm handles fp32 upcast internally
            normed = ir.ops.rms_norm(x, scale, self.variance_epsilon)
            return normed.to(input_dtype)

        if input_dtype == torch.float16:
            combined = x.float() + residual.float()
        else:
            combined = x + residual
        # ir.ops.rms_norm handles fp32 upcast internally
        normed = ir.ops.rms_norm(combined, scale, self.variance_epsilon)
        return normed.to(input_dtype), combined

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path; shares the native implementation."""
        return self.forward_native(x, residual)
406
407


408
# --8<-- [start:rms_norm_gated]
409
410
411
412
413
414
415
416
417
418
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating, whose activation is SiLU ("swish"/"silu", the
      default) or sigmoid ("sigmoid")
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * act(z)
                              If False and z is provided: out = norm(x * act(z))
            device: Device to create parameters on
            dtype: Data type for parameters
            activation: Gate activation name. "swish"/"silu" apply SiLU and
                "sigmoid" applies a plain sigmoid. The name is also forwarded
                to the CUDA kernel in forward_cuda.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def _gate_activation(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the configured gate activation to ``z``.

        Intended to match the ``activation`` name forwarded to ``rmsnorm_fn``
        in forward_cuda. This assumes that kernel uses the same names; confirm
        in ``fla/ops/layernorm_guard``. Previously the native path always used
        SiLU, regardless of ``self.activation``.

        Raises:
            ValueError: If the activation name is not implemented here.
        """
        if self.activation in ("swish", "silu"):
            return F.silu(z)
        if self.activation == "sigmoid":
            return torch.sigmoid(z)
        raise ValueError(f"Unsupported RMSNormGated activation: {self.activation!r}")

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor, cast back to x's dtype.
            All math is done in float32.

        If z is not None (act = the configured activation):
            - norm_before_gate=True: out = norm(x) * act(z)
            - norm_before_gate=False: out = norm(x * act(z))
        """
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        z = z.float() if z is not None else None

        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * self._gate_activation(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm: split the last dim into groups of group_size and
            # normalize each group independently.
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * self._gate_activation(z)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """CUDA path backed by the FLA ``rmsnorm_fn`` kernel."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )

    def forward_xpu(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """XPU path: identical to forward_cuda."""
        return self.forward_cuda(x, z)

527

528
529
530
531
532
533
534
535
536
537
538
539
540
class LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    The input is upcast to float32 and normalized over its last dimension
    (of size ``dim``) with a learnable scale and shift. The result is cast
    back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # The affine parameters are always float32, whatever the input dtype.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized_shape = (self.dim,)
        x_fp32 = x.to(torch.float32)
        y = F.layer_norm(x_fp32, normalized_shape, self.weight, self.bias, self.eps)
        return y.to(input_dtype)