layernorm.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from typing import Optional
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
from vllm.platforms import current_platform
18
from vllm.utils import direct_register_custom_op
19
from vllm import envs
20
21


22
23
24
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` and return the result as a new tensor.

    When batch-invariant mode is active the call is delegated to
    ``rms_norm_batch_invariant``. Otherwise the ``ops.rms_norm`` custom
    kernel is invoked with a freshly allocated output tensor shaped like
    ``x``.

    Args:
        x: Input tensor.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        The tensor handed to the kernel as its output buffer (or the result
        of the batch-invariant implementation).
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    # NOTE: an alternative `ops.rms_norm_opt(x, weight, out, eps)` path,
    # once gated on `envs.VLLM_USE_OPT_OP`, had been hard-disabled with
    # `if False:`. That unreachable branch is intentionally not carried here.
    out = torch.empty_like(x)
    ops.rms_norm(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    In batch-invariant mode the sum is computed out of place and normalized
    with ``rms_norm_batch_invariant``. Otherwise the
    ``ops.fused_add_rms_norm`` custom kernel is called on ``x`` and
    ``residual`` and those same tensor objects are returned (presumably
    updated in place by the kernel -- confirm against the kernel contract).

    Args:
        x: Input tensor.
        residual: Residual tensor to add to ``x``.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A ``(normalized, residual)`` pair.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # The sum is deliberately evaluated twice so the returned residual is
        # a tensor distinct from the one given to the normalization routine.
        return rms_norm_batch_invariant(
            x + residual, weight, variance_epsilon
        ), x + residual

    # NOTE: an alternative `ops.fused_add_rms_norm_opt` path, once gated on
    # `envs.VLLM_USE_OPT_OP`, had been hard-disabled with `if False:`. That
    # unreachable branch is intentionally not carried here.
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


78
79
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel and return its output buffer.

    Args:
        x: Input tensor.
        weight: Weight tensor forwarded to the kernel.
        bias: Bias tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A newly allocated tensor shaped like ``x`` that was passed to the
        kernel as its output argument.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


94
95
96
97
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable to use on ROCm.

    The AITER kernels are chosen only when ``use_aiter`` is set and ``dtype``
    is ``torch.float16`` or ``torch.bfloat16``; in every other case the
    module-level ``rms_norm`` / ``fused_add_rms_norm`` functions are returned.

    Args:
        with_fused_add: Whether the variant that also adds a residual is
            wanted.
        dtype: Data type used to decide AITER eligibility.
        use_aiter: Whether AITER kernels may be used at all.

    Returns:
        The selected callable.
    """
    aiter_ok = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if aiter_ok:
        if with_fused_add:
            return rocm_aiter_ops.rms_norm2d_with_add
        return rocm_aiter_ops.rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
111
112


113
# --8<-- [start:rms_norm]
114
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the last dimension of the inputs.
            eps: Value added to the variance before the reciprocal sqrt.
            var_hidden_size: If given and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                elements of the last dimension.
            has_weight: If True the weight is registered as an
                ``nn.Parameter``; otherwise it stays a plain tensor of ones.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        # The ROCm dispatch attributes exist only on ROCm platforms.
        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor whose last dimension must equal ``hidden_size``.
            variance_epsilon: Value added to the variance.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Output dtype requested by the caller (see the note in
                the body: it is currently replaced by ``x.dtype``).
            weight: Optional scale applied after casting back.
            residual: Optional tensor added to ``x`` before normalizing.
            variance_size_override: If given, the variance is computed over
                only this many leading elements of the last dimension.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``
                or ``hidden_size`` is smaller than ``variance_size_override``.
        """
        # NOTE(review): the caller-supplied ``orig_dtype`` is overwritten by
        # the input's dtype here, so the argument is effectively ignored.
        # Kept as-is to preserve behavior; confirm whether this is intended.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dimension so inputs of any rank work; a
            # fixed ``[:, :, :k]`` index fails on anything but 3-D tensors.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize via the module-level custom-kernel wrappers.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize via the callables selected for ROCm in ``__init__``.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize with Apex's ``fused_rms_norm_affine`` when no residual is given.

        With a residual, the fused-add callable selected in ``__init__`` is
        used. Falls back to ``forward_native`` when a variance size override
        is configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            # ``rocm_norm_func_with_add`` is only set in __init__ on ROCm;
            # use the module-level fused kernel wrapper when it is absent
            # instead of failing with AttributeError.
            fused_add = getattr(self, "rocm_norm_func_with_add", fused_add_rms_norm)
            return fused_add(x, residual, self.weight.data, self.variance_epsilon)

        # Imported only on the path that needs it, so the residual path does
        # not require Apex to be installed.
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        return fused_rms_norm_affine(
            x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize via ``vllm._ipex_ops``.

        Falls back to ``forward_native`` when a variance size override is
        configured.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size``/``eps`` summary used in the module repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    

class FusedRMSNormQuant(nn.Module):
    """Fuse Root mean square normalization and int8 quant.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
    
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
        quant_dtype: torch.dtype = torch.int8,
        update_input: Optional[bool] = True
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
wujl5's avatar
wujl5 committed
338
        i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
                                                 weight=self.weight,
                                                 epsilon=self.variance_epsilon,
                                                 quant_dtype=quant_dtype,
                                                 residual=residual,
                                                 update_input=update_input)

        return i_q, i_s, residual
    

def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run lightop's ``rms_norm_dynamic_per_token_quant`` kernel.

    Allocates the two buffers handed to the kernel: one shaped like
    ``input`` with dtype ``quant_dtype``, and a float32 ``(rows, 1)`` scales
    tensor where ``rows`` is the element count divided by the last dimension.

    Returns:
        ``(output, scales)`` -- the two allocated buffers.
    """
    from lightop.op import rms_norm_dynamic_per_token_quant as lightop_rms_quant

    num_rows = input.numel() // input.shape[-1]
    quantized = torch.empty_like(input, dtype=quant_dtype)
    row_scales = torch.empty(
        (num_rows, 1), device=input.device, dtype=torch.float32
    )

    lightop_rms_quant(
        quantized, input, weight, row_scales, epsilon, residual, update_input
    )
    return quantized, row_scales

def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile.

    Only allocates outputs of the right shape and dtype: a tensor shaped
    like ``input`` with dtype ``quant_dtype`` and a float32 ``(rows, 1)``
    scales tensor; no normalization or quantization is performed.
    """
    num_rows = input.numel() // input.shape[-1]
    quantized = torch.empty_like(input, dtype=quant_dtype)
    row_scales = torch.empty(
        (num_rows, 1), device=input.device, dtype=torch.float32
    )
    return quantized, row_scales

# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")

# Register the fused RMSNorm + quant function as a custom op named
# "fused_rmsquant_customer_impl", with `fused_rmsquant_fake` supplied as the
# fake implementation. `input` and `residual` are declared as mutated args.
direct_register_custom_op(
    op_name="fused_rmsquant_customer_impl",
    op_func=fused_rmsquant_impl,
    fake_impl=fused_rmsquant_fake,
    mutates_args=["input", "residual"],
)
Woosuk Kwon's avatar
Woosuk Kwon committed
391
392


393
# --8<-- [start:gemma_rms_norm]
394
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer with a zero-initialized weight of ``hidden_size``."""
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + variance_epsilon)
        # Scale in float32 first, then cast back to the input dtype.
        scaled = normed * (1.0 + weight.float())
        return scaled.to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual."""
        input_dtype = x.dtype
        # float16 inputs are summed in float32; other dtypes are summed as-is.
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual
        # The un-normalized sum is what gets returned as the new residual.
        new_residual = summed

        summed_fp32 = summed.float()
        mean_square = summed_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = summed_fp32 * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + weight.float())
        return scaled.to(input_dtype), new_residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward_native``, compiling the static helpers on first use.

        While an outer ``torch.compile`` trace is in progress the native path
        is used directly. Otherwise, on the first call, both static helpers
        are replaced on this instance by their ``torch.compile``-wrapped
        versions and ``_is_compiled`` is set.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        already_compiled = getattr(self, "_is_compiled", False)
        if not already_compiled:
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
486
487


488
# --8<-- [start:rms_norm_gated]
489
490
491
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias is used; the slot is registered as None.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        input_dtype = x.dtype
        # All arithmetic below is done in float32.
        x = x.float()
        scale = self.weight.float()
        gate = None if z is None else z.float()
        has_gate = gate is not None

        # Gate first when normalization comes after gating.
        if has_gate and not self.norm_before_gate:
            x = x * F.silu(gate)

        if self.group_size is None:
            # One group: statistics over the whole last dimension.
            mean_square = x.pow(2).mean(dim=-1, keepdim=True)
            normed = x * torch.rsqrt(mean_square + self.eps)
            out = normed * scale
        else:
            # Grouped: split the last dimension into chunks of group_size
            # and normalize each chunk independently.
            from einops import rearrange

            grouped = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            mean_square = grouped.pow(2).mean(dim=-1, keepdim=True)
            normed = grouped * torch.rsqrt(mean_square + self.eps)
            out = rearrange(normed, "... g d -> ... (g d)") * scale

        # Gate last when normalization comes before gating.
        if has_gate and self.norm_before_gate:
            out = out * F.silu(gate)

        return out.to(input_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to ``rmsnorm_fn`` from the FLA ``layernorm_guard`` ops."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    Holds float32 ``weight`` (ones) and ``bias`` (zeros) parameters of size
    ``dim``; the input is normalized in float32 and cast back to its own type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create float32 scale and shift parameters of length ``dim``."""
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Layer-normalize ``x`` over its last ``dim`` elements."""
        x_fp32 = x.float()
        normed = F.layer_norm(
            x_fp32,
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normed.type_as(x)