layernorm.py 19.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from typing import Optional
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
from vllm.platforms import current_platform
18
from vllm.utils import direct_register_custom_op
19
from vllm import envs
20
21


22
23
24
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` and return the result in a new tensor.

    One of three backends is selected at call time:

    * the batch-invariant implementation when ``vllm_is_batch_invariant()``
      is true,
    * the ``rms_norm_opt`` custom op when ``envs.VLLM_USE_OPT_OP`` is set,
    * the ``_custom_ops.rms_norm`` kernel otherwise.

    Args:
        x: Input tensor; the kernel paths write into a buffer allocated with
            ``torch.empty_like(x)``.
        weight: Scale tensor forwarded to the selected backend.
        variance_epsilon: Epsilon forwarded to the selected backend.

    Returns:
        The normalized tensor produced by the selected backend.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    normed = torch.empty_like(x)

    if not envs.VLLM_USE_OPT_OP:
        ops.rms_norm(normed, x, weight, variance_epsilon)
        return normed

    # The trailing ``False`` argument is redundant for the kernel that
    # lightop currently dispatches to.
    torch.ops.vllm.rms_norm_opt(x, weight, normed, variance_epsilon, False)
    return normed


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Backend selection:

    * batch-invariant mode: the sum is computed out of place and handed to
      ``rms_norm_batch_invariant``; new tensors are returned.
    * ``envs.VLLM_USE_OPT_OP`` set: ``fused_add_rms_norm_opt`` custom op.
    * otherwise: ``_custom_ops.fused_add_rms_norm``.

    The two kernel paths take no output argument and this function returns
    the very same ``x`` and ``residual`` objects afterwards, so the kernels
    presumably update both tensors in place.

    Args:
        x: Input activations.
        residual: Residual stream to add to ``x``.
        weight: Scale tensor forwarded to the selected backend.
        variance_epsilon: Epsilon forwarded to the selected backend.

    Returns:
        ``(normalized, new_residual)``.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once and reuse it for both outputs instead of
        # materializing ``x + residual`` twice.
        # NOTE(review): assumes rms_norm_batch_invariant does not modify its
        # input in place - confirm against its implementation.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    if envs.VLLM_USE_OPT_OP:
        torch.ops.vllm.fused_add_rms_norm_opt(
            x,
            residual,
            weight,
            variance_epsilon,
        )
    else:
        ops.fused_add_rms_norm(
            x,
            residual,
            weight,
            variance_epsilon,
        )
    return x, residual


77
78
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``_custom_ops.poly_norm`` kernel on ``x``.

    Args:
        x: Input tensor.
        weight: Weight tensor forwarded to the kernel.
        bias: Bias tensor forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A freshly allocated tensor (``torch.empty_like(x)``) that the kernel
        receives as its output argument.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


93
94
95
96
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Pick the RMSNorm callable to use on ROCm.

    Args:
        with_fused_add: Select the variant that also adds a residual.
        dtype: Weight dtype; the AITER functions are only selected for
            ``torch.float16`` and ``torch.bfloat16``.
        use_aiter: Whether the AITER implementation is requested.

    Returns:
        The ``rocm_aiter_ops`` function when AITER is requested and the dtype
        qualifies, otherwise ``fused_add_rms_norm`` or ``rms_norm`` from this
        module.
    """
    aiter_selected = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if aiter_selected:
        if with_fused_add:
            return rocm_aiter_ops.rms_norm2d_with_add
        return rocm_aiter_ops.rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
110
111


112
# --8<-- [start:rms_norm]
113
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the last dimension of the inputs.
            eps: Value added to the variance before the reciprocal sqrt.
            var_hidden_size: If given and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                features (stored as ``variance_size_override``).
            has_weight: When true the weight is wrapped in ``nn.Parameter``;
                otherwise it stays a plain tensor of ones.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        # The ROCm dispatch callables only exist when running on ROCm.
        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor whose last dimension must equal ``hidden_size``.
            variance_epsilon: Value added to the variance before rsqrt.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Kept for interface compatibility; the output dtype is
                always taken from ``x.dtype`` (see the note in the body).
            weight: Optional scale applied after casting back.
            residual: Optional tensor added to ``x`` before normalizing.
            variance_size_override: If set, compute the variance over only
                the first ``variance_size_override`` features.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when a
            residual was supplied.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``
                or ``variance_size_override`` exceeds ``hidden_size``.
        """
        # NOTE: the caller-supplied ``orig_dtype`` is deliberately ignored in
        # favour of the dtype of the actual input, as in the original code.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dimension so inputs of any rank work
            # (``x[:, :, :n]`` failed on 2-D [tokens, hidden] inputs).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Kernel-backed forward; uses the native path for a variance override."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm forward using the callables selected in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forward that uses apex ``fused_rms_norm_affine`` when no residual is given.

        With a residual, the fused add+norm callable chosen for ROCm is used;
        when that attribute was never set (``__init__`` only creates it on
        ROCm) the module-level ``fused_add_rms_norm`` is used instead of
        failing with ``AttributeError``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        if residual is not None:
            fused_add = getattr(self, "rocm_norm_func_with_add", fused_add_rms_norm)
            return fused_add(x, residual, self.weight.data, self.variance_epsilon)

        return fused_rms_norm_affine(
            x,
            self.weight.data,
            torch.Size((x.shape[-1],)),
            self.variance_epsilon,
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU forward backed by ``vllm._ipex_ops``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size``/``eps`` summary shown in the module repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    

class FusedRMSNormQuant(nn.Module):
    """Fuse Root mean square normalization and int8 quant.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
    
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
        quant_dtype: torch.dtype = torch.int8,
        update_input: Optional[bool] = True
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
wujl5's avatar
wujl5 committed
337
        i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
                                                 weight=self.weight,
                                                 epsilon=self.variance_epsilon,
                                                 quant_dtype=quant_dtype,
                                                 residual=residual,
                                                 update_input=update_input)

        return i_q, i_s, residual
    

def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call lightop's ``rms_norm_dynamic_per_token_quant`` kernel.

    Args:
        input: Activations to normalize and quantize.
        weight: Weight tensor forwarded to the kernel.
        epsilon: Epsilon forwarded to the kernel.
        quant_dtype: dtype of the allocated output buffer.
        residual: Optional residual forwarded to the kernel.
        update_input: Flag forwarded unchanged to the kernel.

    Returns:
        ``(output, scales)``: ``output`` has the shape of ``input`` and dtype
        ``quant_dtype``; ``scales`` is float32 with one row per slice along
        the last dimension of ``input``, shape ``(numel // last_dim, 1)``.
    """
    quantized = torch.empty_like(input, device=input.device, dtype=quant_dtype)
    row_count = input.numel() // input.shape[-1]
    row_scales = torch.empty(
        (row_count, 1), device=input.device, dtype=torch.float32
    )

    from lightop.op import rms_norm_dynamic_per_token_quant as lightop_rms_quant

    lightop_rms_quant(
        quantized, input, weight, row_scales, epsilon, residual, update_input
    )
    return quantized, row_scales

def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile.

    Allocates uninitialized tensors with the same shapes and dtypes that
    ``fused_rmsquant_impl`` returns, without running any kernel.

    Returns:
        ``(output, scales)``: ``output`` matches ``input``'s shape with dtype
        ``quant_dtype``; ``scales`` is float32 of shape
        ``(input.numel() // input.shape[-1], 1)`` on ``input``'s device.
    """
    row_count = input.numel() // input.shape[-1]
    fake_output = torch.empty_like(input, dtype=quant_dtype)
    fake_scales = torch.empty(
        (row_count, 1), device=input.device, dtype=torch.float32
    )
    return fake_output, fake_scales

# Register the fused RMSNorm + quantization kernel wrapper as a custom op.
# FusedRMSNormQuant.forward calls it as
# torch.ops.vllm.fused_rmsquant_customer_impl, so the op name must stay in
# sync with that call site. `fused_rmsquant_fake` is supplied as the fake
# implementation (its docstring: "Fake implementation for torch.compile").
# `input` and `residual` are declared as mutated arguments.
direct_register_custom_op(
    op_name="fused_rmsquant_customer_impl",
    op_func=fused_rmsquant_impl,
    mutates_args=["input", "residual"],
    fake_impl=fused_rmsquant_fake,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
387
388


389
# --8<-- [start:gemma_rms_norm]
390
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create a zero-initialized weight of ``hidden_size`` and store ``eps``."""
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + variance_epsilon)
        # The (1 + w) scale is applied in float32 before casting back.
        scaled = normed * (1.0 + weight.float())
        return scaled.to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual.

        For float16 inputs the sum is formed in float32, and the residual
        that is returned keeps that float32 dtype.
        """
        input_dtype = x.dtype
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        summed_fp32 = summed.float()
        mean_square = summed_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = summed_fp32 * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + weight.float())
        return scaled.to(input_dtype), summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forward that lazily wraps the static helpers in ``torch.compile``.

        While an outer compilation is tracing, the plain native path is used.
        Otherwise, on the first call, both static helpers are replaced on this
        instance by their ``torch.compile`` wrappers, and the native path
        then runs through those wrappers.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        already_compiled = getattr(self, "_is_compiled", False)
        if not already_compiled:
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
482
483


484
# --8<-- [start:rms_norm_gated]
485
486
487
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(
            torch.empty(hidden_size, device=device, dtype=dtype)
        )
        # This layer has no bias; the slot is registered as None.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor, cast back to the
            dtype of ``x``

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        input_dtype = x.dtype
        hidden = x.float()
        scale = self.weight.float()
        gate = None if z is None else z.float()

        # Gate first when the norm is meant to see the gated activations.
        if gate is not None and not self.norm_before_gate:
            hidden = hidden * F.silu(gate)

        if self.group_size is not None:
            # Group RMS norm: normalize each chunk of `group_size` features.
            from einops import rearrange

            grouped = rearrange(hidden, "... (g d) -> ... g d", d=self.group_size)
            mean_square = grouped.pow(2).mean(dim=-1, keepdim=True)
            grouped_normed = grouped * torch.rsqrt(mean_square + self.eps)
            out = rearrange(grouped_normed, "... g d -> ... (g d)") * scale
        else:
            # Standard RMS norm across the last dimension.
            mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
            out = hidden * torch.rsqrt(mean_square + self.eps) * scale

        # Gate afterwards when the norm is meant to see the raw activations.
        if gate is not None and self.norm_before_gate:
            out = out * F.silu(gate)

        return out.to(input_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to ``rmsnorm_fn`` from the FLA ``layernorm_guard`` ops."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        fn_kwargs = dict(
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
        return rmsnorm_fn(x, self.weight, self.bias, **fn_kwargs)
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    The affine parameters are kept in float32; inputs are normalized in
    float32 and the result is cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create float32 ``weight`` (ones) and ``bias`` (zeros) of size ``dim``."""
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its last ``dim`` features and keep its dtype."""
        normed_fp32 = F.layer_norm(
            x.float(),
            (self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normed_fp32.type_as(x)