layernorm.py 19.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from typing import Optional
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
from vllm.platforms import current_platform
18
from vllm.utils import direct_register_custom_op
19
from vllm import envs
20

21
22
from lightop import rms_norm_dynamic_per_token_quant

23

24
25
26
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` out of place.

    Args:
        x: Input tensor to normalize.
        weight: Learned scale passed to the kernel.
        variance_epsilon: Epsilon added to the variance for stability.

    Returns:
        A freshly allocated tensor (``torch.empty_like(x)``) filled by
        ``ops.rms_norm``, or the result of ``rms_norm_batch_invariant`` when
        batch-invariant mode is active.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Batch-invariant mode has its own dedicated implementation.
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    out = torch.empty_like(x)
    # NOTE: an env-gated ``ops.rms_norm_opt`` variant previously sat here
    # behind a hard-coded ``if False:``; it was unreachable and was dropped.
    ops.rms_norm(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual to add to ``x`` before normalization.
        weight: Learned scale passed to the kernel.
        variance_epsilon: Epsilon added to the variance for stability.

    Returns:
        ``(normalized, new_residual)``. On the default path the kernel is
        called for its side effects and the same ``x``/``residual`` objects
        are returned (presumably updated in place by
        ``ops.fused_add_rms_norm``; confirm against the kernel).
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once and reuse it for both outputs instead of
        # evaluating ``x + residual`` twice.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE: an env-gated ``ops.fused_add_rms_norm_opt`` variant previously sat
    # here behind a hard-coded ``if False:``; it was unreachable and was dropped.
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


80
81
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``poly_norm`` custom kernel on ``x``.

    The kernel writes into a buffer allocated with ``torch.empty_like(x)``,
    which is returned to the caller.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


96
97
98
99
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable used on ROCm.

    AITER kernels are chosen only when requested *and* the dtype is fp16 or
    bf16. Otherwise this falls back to the module-level CUDA-path helpers.
    ``with_fused_add`` chooses between the residual-add variant and the
    plain variant.
    """
    aiter_supported = use_aiter and dtype in [torch.float16, torch.bfloat16]

    if not aiter_supported:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    return rocm_aiter_ops.rms_norm
113
114


115
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Expected size of the input's last dimension and the
                length of ``weight``.
            eps: Epsilon added to the variance before ``rsqrt``.
            var_hidden_size: If set and different from ``hidden_size``, only
                the first ``var_hidden_size`` channels contribute to the
                variance. This forces every ``forward_*`` onto the native path.
            has_weight: If False, ``weight`` stays a plain all-ones tensor
                instead of an ``nn.Parameter``.
            dtype: dtype of ``weight``; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Only keep an override when it actually differs from hidden_size.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        # The ROCm-specific callables only exist on ROCm; other forward paths
        # must not assume these attributes are present.
        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Epsilon added to the variance.
            hidden_size: Expected size of ``x``'s last dimension.
            orig_dtype: dtype the output (and returned residual) is cast to.
                Previously this was silently overwritten with ``x.dtype``.
            weight: Optional scale applied after the cast to ``orig_dtype``.
            residual: Optional tensor added to ``x`` (in fp32) before
                normalization.
            variance_size_override: If set, only the first this-many channels
                of the last dimension contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size`` or
                ``hidden_size < variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dimension so inputs of any rank work
            # (``x[:, :, :k]`` required exactly 3-D input).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: custom kernels, or native when a variance override is set."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path: uses the callables selected in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apex path: ``fused_rms_norm_affine`` for the no-residual case."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            # ``rocm_norm_func_with_add`` is only set on ROCm (see __init__);
            # fall back to the generic fused kernel elsewhere instead of
            # raising AttributeError.
            add_fn = getattr(self, "rocm_norm_func_with_add", fused_add_rms_norm)
            return add_fn(x, residual, self.weight.data, self.variance_epsilon)

        # Deferred import: apex is only needed on this branch.
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        return fused_rms_norm_affine(
            x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path: IPEX kernels, or native when a variance override is set."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
    

class FusedRMSNormQuant(nn.Module):
    """Fused RMS normalization followed by dynamic per-token quantization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight,
    then quantizes the result (``torch.int8`` by default) through the
    ``fused_rmsquant`` custom-op wrapper.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Stored to mirror RMSNorm's constructor; NOTE(review): forward() does
        # not currently honor this override.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
        quant_dtype: torch.dtype = torch.int8,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Normalize and quantize ``x``.

        Args:
            x: Input activations.
            residual: Optional residual forwarded to the kernel.
            quant_dtype: Target dtype of the quantized output.

        Returns:
            ``(x_quantized, scales, residual)``. ``residual`` is the same
            object that was passed in; whether the kernel updates it in place
            is kernel-defined (TODO confirm against lightop).
        """
        x_q, x_scales = fused_rmsquant(
            x, self.weight, self.variance_epsilon, quant_dtype, residual
        )
        return x_q, x_scales, residual
    

def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation registered for the fused RMSNorm+quant op.

    Forwards every argument positionally to
    ``lightop.rms_norm_dynamic_per_token_quant`` and returns its two outputs
    (quantized tensor, scales). Unpacking enforces that exactly two are
    produced.
    """
    kernel_result = rms_norm_dynamic_per_token_quant(
        input, weight, epsilon, quant_dtype, residual, update_input
    )
    quantized, per_token_scales = kernel_result
    return quantized, per_token_scales

def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile

    Only metadata matters here. The output mirrors ``input``'s shape in
    ``quant_dtype``, and scales have one fp32 row per token, i.e.
    ``(input.numel() // input.shape[-1], 1)`` on ``input``'s device.
    """
    hidden = input.shape[-1]
    num_tokens = input.numel() // hidden
    fake_scales = torch.ones(
        (num_tokens, 1), device=input.device, dtype=torch.float32
    )
    fake_output = torch.zeros_like(input, dtype=quant_dtype)
    return fake_output, fake_scales

# Register the lightop-backed kernel as a custom op in vLLM's op library so
# torch.compile can trace it via the fake implementation.
# NOTE(review): mutates_args is empty, yet the kernel takes an ``update_input``
# flag (default True) that presumably writes back into ``input``/``residual``.
# If so, those args should be declared here so functionalization sees the
# mutation; confirm against lightop.
direct_register_custom_op(
    op_name="rms_norm_dynamic_per_token_quant",
    op_func=fused_rmsquant_impl,
    fake_impl=fused_rmsquant_fake,
    mutates_args=[],
)

def fused_rmsquant(
    input: torch.Tensor,
    rms_weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMS-normalize and dynamically quantize ``input`` through the custom op.

    Args:
        input: Activations to normalize and quantize.
        rms_weight: RMSNorm scale.
        epsilon: Epsilon added to the variance.
        quant_dtype: Target dtype of the quantized output.
        residual: Optional residual forwarded to the kernel.
        update_input: Flag forwarded to the kernel unchanged.

    Returns:
        ``(quantized, scales)`` as produced by the registered op.
    """
    # This module registers the op with op_name="rms_norm_dynamic_per_token_quant".
    # The previous call to ``torch.ops.vllm.fused_rmsquant`` targeted a name
    # this file never registers and would fail at call time.
    quantized, scales = torch.ops.vllm.rms_norm_dynamic_per_token_quant(
        input=input,
        weight=rms_weight,
        epsilon=epsilon,
        quant_dtype=quant_dtype,
        residual=residual,
        update_input=update_input,
    )
    return quantized, scales
Woosuk Kwon's avatar
Woosuk Kwon committed
397
398


399
# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # The effective scale is (1 + weight), so zero-init means identity.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        input_dtype = x.dtype
        hidden = x.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        scaled = (hidden * inv_rms) * (1.0 + weight.float())
        return scaled.to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual."""
        input_dtype = x.dtype
        # fp16 inputs are summed in fp32; the fp32 sum is what gets returned
        # as the new residual.
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        hidden = summed.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = (hidden * inv_rms) * (1.0 + weight.float())
        return normed.to(input_dtype), summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the native math, torch.compile'd once per instance outside tracing."""
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        already_compiled = getattr(self, "_is_compiled", False)
        if not already_compiled:
            # Instance attributes shadow the staticmethods, so forward_native
            # picks up the compiled versions on subsequent calls.
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
492
493


494
# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension (length of ``weight``).
            eps: Epsilon for numerical stability.
            group_size: If not None, normalize independently over consecutive
                groups of ``group_size`` channels. ``None`` is equivalent to
                a single group spanning ``hidden_size``.
            norm_before_gate: With a gate ``z``, True gives
                ``norm(x) * silu(z)`` and False gives ``norm(x * silu(z))``.
            device: Device to create parameters on.
            dtype: Data type for parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias, but the attribute is registered so ``self.bias`` is None.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` to all ones."""
        nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Native PyTorch RMS normalization with optional SiLU gating.

        Computation happens in fp32 and the result is cast back to ``x``'s
        dtype.

        Args:
            x: Input tensor.
            z: Optional gating tensor.

        Returns:
            Normalized (and optionally gated) tensor. With ``z`` given:
            ``norm_before_gate=True`` gives ``norm(x) * silu(z)`` and
            ``norm_before_gate=False`` gives ``norm(x * silu(z))``.
        """
        out_dtype = x.dtype
        x32 = x.float()
        w32 = self.weight.float()
        gate = None if z is None else F.silu(z.float())

        # Gate first when normalization should see the gated input.
        if gate is not None and not self.norm_before_gate:
            x32 = x32 * gate

        if self.group_size is None:
            # Single group: normalize over the whole last dimension.
            inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            out = x32 * inv_rms * w32
        else:
            from einops import rearrange

            grouped = rearrange(x32, "... (g d) -> ... g d", d=self.group_size)
            inv_rms = torch.rsqrt(
                grouped.pow(2).mean(dim=-1, keepdim=True) + self.eps
            )
            out = rearrange(grouped * inv_rms, "... g d -> ... (g d)") * w32

        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(out_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the FLA ``rmsnorm_fn`` kernel with this layer's settings."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616


class LayerNorm(nn.Module):
    """Layer normalization over the last ``dim`` features, computed in fp32.

    The affine parameters are stored in float32. The input is upcast for the
    computation and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Identity affine transform at init: scale 1, shift 0.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized_shape = (self.dim,)
        x_fp32 = x.float()
        y = F.layer_norm(
            x_fp32,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return y.type_as(x)