layernorm.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from typing import Optional
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
from vllm.platforms import current_platform
18
from vllm.utils import direct_register_custom_op
19
from vllm import envs
20

wujl5's avatar
wujl5 committed
21
22
from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant

23

24

25
26
27
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` and return the result.

    When batch-invariant mode is active the call is delegated to
    ``rms_norm_batch_invariant``; otherwise the ``ops.rms_norm`` custom kernel
    writes into a freshly allocated buffer shaped like ``x``.

    Args:
        x: Input tensor.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        The normalized tensor.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    # NOTE: an ``ops.rms_norm_opt`` fast path used to live here behind
    # ``envs.VLLM_USE_OPT_OP``; it had been hard-disabled with ``if False:``,
    # so the unreachable branch was removed.
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input tensor.
        residual: Residual tensor added to ``x`` before normalization.
        weight: Scale tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A ``(normalized, residual)`` pair. In batch-invariant mode both
        elements are new tensors (the second is ``x + residual``); otherwise
        the ``x`` and ``residual`` objects that were passed in are returned
        after the ``ops.fused_add_rms_norm`` call.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once and reuse it for both outputs.
        # NOTE(review): assumes rms_norm_batch_invariant does not write to its
        # input — confirm against the batch_invariant module.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE: the ``ops.fused_add_rms_norm_opt`` branch was hard-disabled with
    # ``if False:`` and has been removed as unreachable.
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


81
82
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel on ``x``.

    Args:
        x: Input tensor.
        weight: Weight tensor forwarded to the kernel.
        bias: Bias tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A tensor allocated with ``torch.empty_like(x)`` that the kernel
        receives as its output argument.
    """
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    ops.poly_norm(
        out,
        x,
        weight,
        bias,
        variance_epsilon,
    )
    return out


97
98
99
100
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable to use on ROCm.

    Args:
        with_fused_add: If True, return the variant that also adds a residual.
        dtype: Dtype the layer runs in; AITER is only considered for
            ``torch.float16`` and ``torch.bfloat16``.
        use_aiter: Whether the AITER kernels were requested.

    Returns:
        One of the ``rocm_aiter_ops`` functions when AITER applies, otherwise
        ``fused_add_rms_norm`` or ``rms_norm`` from this module.
    """
    use_aiter = use_aiter and dtype in [
        torch.float16,
        torch.bfloat16,
    ]

    if use_aiter and with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    if use_aiter:
        return rocm_aiter_ops.rms_norm

    # fall back to CUDA implementation
    if with_fused_add:
        return fused_add_rms_norm
    return rms_norm
114
115


116
# --8<-- [start:rms_norm]
117
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the last dimension of the inputs.
            eps: Epsilon added to the variance before the reciprocal sqrt.
            var_hidden_size: If given and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                features.
            has_weight: If True the weight is registered as a parameter;
                otherwise it stays a plain tensor of ones.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # An override equal to hidden_size changes nothing, so store None.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once so forward_hip does no dispatching.
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor; its last dimension must equal ``hidden_size``.
            variance_epsilon: Epsilon added to the variance.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Dtype the outputs are cast to. It is honored as
                passed (it used to be overwritten with ``x.dtype``).
            weight: Optional scale applied after the cast to ``orig_dtype``.
            residual: Optional tensor added to ``x`` before normalization.
            variance_size_override: If set, compute the variance over only the
                first ``variance_size_override`` features of the last dim.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``,
                or ``variance_size_override`` exceeds ``hidden_size``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Ellipsis keeps this valid for any input rank, not only 3-D.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Custom-kernel path; uses forward_native when a variance override is set."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the callables chosen in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apex path: ``fused_rms_norm_affine`` when no residual is given.

        With a residual, the fused-add callable selected in ``__init__`` is
        used when it exists. It is only set on ROCm, so on other platforms
        this falls back to the module-level ``fused_add_rms_norm`` instead of
        raising ``AttributeError``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        add_residual = residual is not None

        if add_residual:
            norm_with_add = getattr(
                self, "rocm_norm_func_with_add", fused_add_rms_norm
            )
            return norm_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return fused_rms_norm_affine(
                x,
                self.weight.data,
                torch.Size((x.shape[-1],)),
                self.variance_epsilon,
            )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path backed by ``vllm._ipex_ops``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size``/``eps`` summary used in the module repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    

class FusedRMSNormQuant(nn.Module):
    """Fuse Root mean square normalization and int8 quant.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
    
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
wujl5's avatar
wujl5 committed
338
339
        quant_dtype: torch.dtype = torch.int8,
        update_input: Optional[bool] = True
340
341
342
343
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        
        x, x_scales = fused_rmsquant(x, self.weight, 
                                     self.variance_epsilon,
wujl5's avatar
wujl5 committed
344
345
                                     quant_dtype, residual,
                                     update_input)
346
347
348
349
350
351
352
353
354
355
356
        return x, x_scales, residual
    

def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the fused RMSNorm + per-token quant op.

    Allocates the quantized output and the scale buffer, then hands them to
    the lightop kernel.

    Args:
        input: Tensor to normalize and quantize.
        weight: RMSNorm weight.
        epsilon: RMSNorm epsilon.
        quant_dtype: Dtype of the quantized output buffer.
        residual: Optional residual forwarded to the kernel.
        update_input: Flag forwarded to the kernel.
            NOTE(review): the name suggests the kernel may write back into
            ``input``/``residual`` — confirm against the lightop docs.

    Returns:
        ``(output, scales)`` where ``output`` has the shape of ``input`` in
        ``quant_dtype`` and ``scales`` is float32 with shape
        ``(input.numel() // input.shape[-1], 1)``.
    """
    # empty_like already inherits the device from ``input``.
    output = torch.empty_like(input, dtype=quant_dtype)
    scales = torch.empty(
        (input.numel() // input.shape[-1], 1),
        device=input.device,
        dtype=torch.float32,
    )
    ligtop_rms_norm_dynamic_per_token_quant(
        output, input, weight, scales, epsilon, residual, update_input
    )
    return output, scales

def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape/dtype stand-in for the fused RMSNorm-quant op under torch.compile.

    No normalization is performed: the first result is all zeros shaped like
    ``input`` in ``quant_dtype``; the second is a float32 column of ones with
    one entry per row of the last dimension.
    """
    num_rows = input.numel() // input.shape[-1]
    fake_scales = torch.ones(
        (num_rows, 1), dtype=torch.float32, device=input.device
    )
    fake_output = torch.zeros_like(input, dtype=quant_dtype)
    return fake_output, fake_scales

wujl5's avatar
wujl5 committed
381
382
383
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")

384
385
386
387
388
# Register the fused RMSNorm + per-token quant op together with its fake
# implementation for torch.compile.
direct_register_custom_op(
    op_name="rms_norm_dynamic_per_token_quant",
    op_func=fused_rmsquant_impl,
    # NOTE(review): registered as non-mutating, yet the implementation forwards
    # ``residual``/``update_input`` to the kernel. Confirm the kernel never
    # writes to its inputs; otherwise they must be listed here.
    mutates_args=[],
    fake_impl=fused_rmsquant_fake,
)

def fused_rmsquant(input: torch.Tensor,
    rms_weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True):
    """Run lmslim's ``lm_faster_rmsquant`` and return its two results.

    Args:
        input: Tensor to normalize and quantize.
        rms_weight: RMSNorm weight.
        epsilon: RMSNorm epsilon.
        quant_dtype: Requested quantized dtype.
        residual: Optional residual forwarded to the op.
        update_input: Flag forwarded unchanged to the op.

    Returns:
        The ``(quantized, scales)`` pair produced by ``lm_faster_rmsquant``.
    """
    # Imported lazily so lmslim is only required when this path is used.
    from lmslim.quantize.quant_ops import lm_faster_rmsquant

    i_q, _scales = lm_faster_rmsquant(
        input=input,
        rms_weight=rms_weight,
        epsilon=epsilon,
        quant_dtype=quant_dtype,
        residual=residual,
        update_input=update_input,
    )
    return i_q, _scales
Woosuk Kwon's avatar
Woosuk Kwon committed
406
407


408
# --8<-- [start:gemma_rms_norm]
409
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer with a zero-initialized weight of ``hidden_size``."""
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Scale in float32 and cast afterwards (see the class docstring).
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual.

        The returned residual is the sum ``x + residual``; for float16 inputs
        that sum is computed and returned in float32.
        """
        orig_dtype = x.dtype
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is None:
            return self._forward_static_no_residual(
                self.weight.data, self.variance_epsilon, x
            )
        else:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run forward_native, compiling the static helpers on first eager use.

        While an outer ``torch.compile`` trace is in progress the native path
        is used directly. Otherwise both static helpers are wrapped with
        ``torch.compile`` once and stored on this instance.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            # Instance attributes shadow the class-level staticmethods, so
            # forward_native picks up the compiled versions from here on.
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
501
502


503
# --8<-- [start:rms_norm_gated]
504
505
506
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # Registered as None so ``self.bias`` exists for forward_cuda.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        # All math runs in float32; the result is cast back at the end.
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        z = z.float() if z is not None else None

        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * weight
        else:
            # Group RMS norm
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to ``rmsnorm_fn`` from the fla ``layernorm_guard`` module."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    Weight and bias are kept in float32; the input is upcast to float32 for
    the normalization and the result is cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create float32 weight (ones) and bias (zeros) of length ``dim``."""
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its last ``dim`` features.

        Returns:
            A tensor with the same type as ``x``.
        """
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)