"vscode:/vscode.git/clone" did not exist on "f1a7696f7fbe15a1a48840a1cd7840d705112b23"
layernorm.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from typing import Optional
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
12

13
14
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
15
    vllm_is_batch_invariant,
16
)
17
from vllm.platforms import current_platform
18
from vllm.utils import direct_register_custom_op
19
from vllm import envs
20
21


22
23
24
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMS-normalize ``x`` with per-channel ``weight`` and return a new tensor.

    Dispatch order:
      1. batch-invariant mode -> ``rms_norm_batch_invariant``;
      2. ``envs.VLLM_USE_OPT_OP`` -> ``torch.ops.vllm.rms_norm_opt``;
      3. otherwise the regular ``ops.rms_norm`` custom kernel.
    Paths 2 and 3 write into a freshly allocated ``torch.empty_like(x)``.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    result = torch.empty_like(x)
    if not envs.VLLM_USE_OPT_OP:
        ops.rms_norm(result, x, weight, variance_epsilon)
        return result

    # Note the argument order differs from ops.rms_norm: (input, weight, out, ...).
    # The trailing `False` is redundant for the lightop kernel currently bound
    # to this op (presumably kept only to satisfy the registered signature).
    torch.ops.vllm.rms_norm_opt(x, weight, result, variance_epsilon, False)
    return result


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual tensor added to ``x`` before normalization.
        weight: Per-channel scale applied after normalization.
        variance_epsilon: Epsilon added to the variance.

    Returns:
        ``(normed, x + residual)``.

        On the default path, the very same ``x`` and ``residual`` tensors passed
        in are returned after ``ops.fused_add_rms_norm``. The kernel therefore
        writes its results into the inputs in place.

        In batch-invariant mode, new tensors are returned and the inputs are not
        modified.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once (it was previously evaluated twice).
        # NOTE(review): assumes rms_norm_batch_invariant does not modify its
        # input in place, since the same tensor is returned as the residual.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE: an optimized `ops.fused_add_rms_norm_opt` call used to sit here
    # behind a commented-out `envs.VLLM_USE_OPT_OP` check. It was hard-disabled
    # with `if False:` (unreachable), so it has been removed. Re-add it behind
    # the env flag once that kernel is validated.
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


78
79
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel on ``x``.

    The kernel writes into a freshly allocated tensor shaped like ``x``,
    which is returned. ``x`` itself is passed as a read input.
    """
    from vllm import _custom_ops as ops

    normalized = torch.empty_like(x)
    ops.poly_norm(normalized, x, weight, bias, variance_epsilon)
    return normalized


94
95
96
97
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Pick the RMSNorm callable to use on ROCm.

    AITER kernels are used only when ``use_aiter`` is truthy and ``dtype`` is
    fp16 or bf16. Otherwise this falls back to the module-level
    ``fused_add_rms_norm`` / ``rms_norm`` functions.

    Args:
        with_fused_add: Select the residual-add variant.
        dtype: Weight dtype the norm will run with.
        use_aiter: Whether AITER RMSNorm is enabled.
    """
    aiter_ok = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if not aiter_ok:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    return rocm_aiter_ops.rms_norm
111
112


113
# --8<-- [start:rms_norm]
114
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Size of the last (normalized) dimension of the input.
            eps: Epsilon added to the variance for numerical stability.
            var_hidden_size: If set and different from ``hidden_size``, the
                variance is computed over only the first ``var_hidden_size``
                channels of the last dimension. Every forward_* then falls
                back to forward_native.
            has_weight: If False, ``self.weight`` is a plain all-ones tensor
                rather than an ``nn.Parameter``, and forward_native skips the
                weight multiply.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Epsilon added to the variance.
            hidden_size: Expected size of ``x``'s last dimension.
            orig_dtype: dtype that the normalized output and the updated
                residual are cast to. The cast happens before the weight
                multiply.
            weight: Optional per-channel scale.
            residual: Optional tensor added to ``x`` before normalization.
            variance_size_override: If set, the variance is computed over only
                the first ``variance_size_override`` channels of the last dim.

        Returns:
            The normalized tensor, or ``(normalized, x + residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x``'s last dimension is not ``hidden_size``, or if
                ``hidden_size < variance_size_override``.
        """
        # `orig_dtype` is honored as passed. It was previously overwritten with
        # `x.dtype`, which made the parameter dead for callers passing a
        # different target dtype.
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )
            # Slice the *last* dim for any input rank. The previous
            # `x[:, :, :n]` only worked for 3-D inputs: it raised IndexError
            # on 2-D inputs and sliced the wrong axis on 4-D+ inputs.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Custom-kernel path.

        Falls back to forward_native when a variance size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the functions chosen in __init__ via
        dispatch_rocm_rmsnorm_func."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apex path.

        Uses apex's ``fused_rms_norm_affine`` when there is no residual. With
        a residual, it reuses the ROCm fused-add function.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        add_residual = residual is not None
        if add_residual:
            # NOTE(review): `rocm_norm_func_with_add` is only set in __init__
            # when `current_platform.is_rocm()`. On any other platform this
            # branch would raise AttributeError.
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return fused_rms_norm_affine(
                x,
                self.weight.data,
                torch.Size((x.shape[-1],)),
                self.variance_epsilon,
            )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path via IPEX ops.

        With a residual, ``x`` and ``residual`` are updated in place and
        returned.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    

class FusedRMSNormQuant(nn.Module):
    """Fuse Root mean square normalization and int8 quant.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
    
    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
        quant_dtype: torch.dtype = torch.int8,
        update_input: Optional[bool] = True
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
wujl5's avatar
wujl5 committed
338
        i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
                                                 weight=self.weight,
                                                 epsilon=self.variance_epsilon,
                                                 quant_dtype=quant_dtype,
                                                 residual=residual,
                                                 update_input=update_input)

        return i_q, i_s, residual
    

def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation of the ``fused_rmsquant_customer_impl`` op.

    Allocates the quantized output (``quant_dtype``, same shape as ``input``)
    and one float32 scale per row of the flattened input.

    Both buffers are then filled by lightop's
    ``rms_norm_dynamic_per_token_quant``. ``residual`` and ``update_input`` are
    forwarded to it unchanged.
    """
    num_rows = input.numel() // input.shape[-1]
    quantized = torch.empty_like(input, device=input.device, dtype=quant_dtype)
    row_scales = torch.empty(
        (num_rows, 1), device=input.device, dtype=torch.float32
    )

    from lightop.op import rms_norm_dynamic_per_token_quant

    rms_norm_dynamic_per_token_quant(
        quantized, input, weight, row_scales, epsilon, residual, update_input
    )
    return quantized, row_scales

def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation used by torch.compile for shape inference.

    Returns uninitialized tensors shaped like the real op's outputs:
    ``(quant_dtype tensor shaped like input, float32 scales of shape
    (rows, 1))``.
    """
    num_rows = input.numel() // input.shape[-1]
    return (
        torch.empty_like(input, dtype=quant_dtype),
        torch.empty((num_rows, 1), device=input.device, dtype=torch.float32),
    )

# Expose fused_rmsquant_impl as torch.ops.vllm.fused_rmsquant_customer_impl,
# with fused_rmsquant_fake providing output metadata under torch.compile.
direct_register_custom_op(
    op_name="fused_rmsquant_customer_impl",
    op_func=fused_rmsquant_impl,
    fake_impl=fused_rmsquant_fake,
    # Declared as mutated so the compiler does not assume these buffers are
    # unchanged after the call (presumably the lightop kernel writes them in
    # place; confirm against lightop).
    mutates_args=["input", "residual"],
)
Woosuk Kwon's avatar
Woosuk Kwon committed
388
389


390
# --8<-- [start:gemma_rms_norm]
391
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from RMSNorm in two ways:
        1. The scale is (1 + w) rather than w, so a zero-initialized weight
           is the identity scale.
        2. The weight is applied in fp32 before casting back, i.e.
           (x * w).to(orig_dtype) rather than x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``x`` in fp32, scale by (1 + weight), cast back to x.dtype."""
        input_dtype = x.dtype
        hidden = x.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        scaled = hidden * inv_rms * (1.0 + weight.float())
        return scaled.to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Add ``residual``, normalize, and return ``(normed, x + residual)``.

        For fp16 input, the sum is taken in fp32 and the returned residual stays
        in fp32. For other dtypes, the sum keeps the input dtype.
        """
        input_dtype = x.dtype
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        hidden = summed.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = hidden * inv_rms * (1.0 + weight.float())
        return normed.to(input_dtype), summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch path; dispatches on whether a residual is given."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run forward_native, torch.compile-ing the static helpers once.

        When already inside a torch.compile trace, compilation is skipped.
        Otherwise, the first call replaces the two static helpers with compiled
        versions stored as instance attributes, which shadow the staticmethods.
        """
        if not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        ):
            self._forward_static_no_residual = torch.compile(  # type: ignore
                self._forward_static_no_residual
            )
            self._forward_static_with_residual = torch.compile(  # type: ignore
                self._forward_static_with_residual
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
483
484


485
# --8<-- [start:rms_norm_gated]
486
487
488
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS normalization with optional SiLU gating.

    forward_native supports:
    - standard RMS normalization over the last dimension,
    - group RMS normalization (each group of ``group_size`` channels is
      normalized independently),
    - optional gating with SiLU(z), applied before or after normalization.

    forward_cuda delegates to ``rmsnorm_fn`` from the FLA layernorm_guard ops.
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Create the layer; the weight is initialized to ones.

        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Epsilon added to the variance.
            group_size: Normalize groups of this many channels separately.
                None means a single group spanning ``hidden_size``.
            norm_before_gate: With a gate ``z``, True gives
                ``norm(x) * silu(z)`` and False gives ``norm(x * silu(z))``.
            device: Device for the weight parameter.
            dtype: dtype for the weight parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
        # No bias; registered as None so `self.bias` exists for forward_cuda.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the weight to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Pure-PyTorch gated RMS norm; computes in fp32, returns ``x.dtype``.

        Args:
            x: Input tensor.
            z: Optional gate tensor.

        Returns:
            ``norm(x) * silu(z)`` if ``norm_before_gate`` else
            ``norm(x * silu(z))``; just ``norm(x)`` when ``z`` is None.
        """
        out_dtype = x.dtype
        hidden = x.float()
        w = self.weight.float()
        gate = None if z is None else z.float()

        gate_first = gate is not None and not self.norm_before_gate
        if gate_first:
            hidden = hidden * F.silu(gate)

        if self.group_size is None:
            inv_rms = torch.rsqrt(hidden.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            out = hidden * inv_rms * w
        else:
            from einops import rearrange

            grouped = rearrange(hidden, "... (g d) -> ... g d", d=self.group_size)
            inv_rms = torch.rsqrt(grouped.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            out = rearrange(grouped * inv_rms, "... g d -> ... (g d)") * w

        gate_last = gate is not None and self.norm_before_gate
        if gate_last:
            out = out * F.silu(gate)

        return out.to(out_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the FLA ``rmsnorm_fn`` kernel with this layer's settings."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    Parameters are created in float32, and the normalization is always
    computed in float32. The result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        # The parameters are cast to fp32 as well. If the module was moved to
        # another dtype (e.g. `.half()` / `.to(torch.bfloat16)` / `.double()`),
        # F.layer_norm would otherwise reject the fp32-input /
        # non-fp32-weight mix. `.float()` is a no-op when they are already fp32.
        return F.layer_norm(
            x.float(), (self.dim,), self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)