layernorm.py 16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
13
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
14
    vllm_is_batch_invariant,
15
)
16
from vllm.platforms import current_platform
17
from vllm import envs
18
19


20
21
22
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` with the custom ``ops.rms_norm`` kernel.

    When batch-invariant mode is active, the batch-invariant implementation
    is used instead.

    Args:
        x: Input tensor to normalize.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Numerical-stability epsilon forwarded to the kernel.

    Returns:
        A freshly allocated tensor (``torch.empty_like(x)``) that is passed to
        the kernel as its output buffer.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    # NOTE: an ``ops.rms_norm_opt`` path (meant to be gated on
    # envs.VLLM_USE_OPT_OP) previously sat here behind a hard-coded
    # ``if False:`` and was unreachable; only the standard kernel is used.
    out = torch.empty_like(x)
    ops.rms_norm(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input tensor.
        residual: Residual tensor added to ``x`` before normalization.
        weight: Per-channel scale forwarded to the normalization.
        variance_epsilon: Numerical-stability epsilon.

    Returns:
        ``(normalized, new_residual)``. In batch-invariant mode
        ``new_residual`` is the freshly computed ``x + residual``. On the
        custom-kernel path the same ``x`` and ``residual`` tensor objects that
        were handed to ``ops.fused_add_rms_norm`` are returned, i.e. the kernel
        is expected to update them in place.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Materialize the sum once and reuse it for both outputs rather than
        # computing ``x + residual`` twice.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE: an ``ops.fused_add_rms_norm_opt`` path (meant to be gated on
    # envs.VLLM_USE_OPT_OP) previously sat here behind a hard-coded
    # ``if False:`` and was unreachable; only the standard kernel is used.
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


76
77
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel on ``x``.

    A buffer with the same shape and dtype as ``x`` is allocated, passed to
    the kernel as its output argument, and returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


92
93
94
95
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm callable to use on ROCm.

    AITER kernels are chosen only when ``use_aiter`` is True *and* ``dtype``
    is float16 or bfloat16. Otherwise the module-level ``rms_norm`` /
    ``fused_add_rms_norm`` wrappers are returned.

    Args:
        with_fused_add: Return the residual-add variant when True.
        dtype: Dtype used to decide AITER eligibility.
        use_aiter: Whether AITER RMSNorm kernels are enabled.

    Returns:
        The selected callable. ``RMSNorm.forward_hip`` invokes the plain
        variant as ``f(x, weight, eps)`` and the fused variant as
        ``f(x, residual, weight, eps)``.
    """
    aiter_eligible = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if not aiter_eligible:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    return rocm_aiter_ops.rms_norm
109
110


111
# --8<-- [start:rms_norm]
112
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the inputs' last dimension.
            eps: Epsilon added to the mean square before the rsqrt.
            var_hidden_size: If set and different from ``hidden_size``, only
                the first ``var_hidden_size`` channels contribute to the
                variance. Every backend falls back to ``forward_native`` in
                that case.
            has_weight: If False, ``self.weight`` is a plain all-ones tensor
                instead of an ``nn.Parameter``, and the native path skips the
                scale.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Only keep an override when it actually differs from hidden_size.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Resolve the ROCm kernels once; forward_hip uses these.
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Epsilon added to the mean square.
            hidden_size: Expected size of ``x``'s last dimension.
            orig_dtype: Dtype of the returned tensor(s). The computation
                itself runs in float32.
            weight: Optional per-channel scale, applied after casting back to
                ``orig_dtype``.
            residual: Optional tensor added to ``x`` in float32 before
                normalization. The sum, cast to ``orig_dtype``, is returned as
                the new residual.
            variance_size_override: If set, only the first this-many entries
                of the last dimension contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size`` or
                ``hidden_size < variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )
            # Slice the last dimension for inputs of any rank
            # (e.g. 2-D (tokens, hidden) as well as 3-D batched inputs).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path via the custom-op wrappers; native path for overrides."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the kernels resolved in ``__init__``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """RMSNorm using apex's fused kernel for the no-residual case.

        The residual-add case uses the ROCm-dispatched kernel when it was
        resolved in ``__init__``. That attribute is only set on ROCm, so
        otherwise it falls back to ``fused_add_rms_norm``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            fused_add = getattr(self, "rocm_norm_func_with_add", fused_add_rms_norm)
            return fused_add(x, residual, self.weight.data, self.variance_epsilon)

        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        return fused_rms_norm_affine(
            x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path via IPEX ops; the residual variant returns ``(x, residual)``."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
299
300


301
# --8<-- [start:gemma_rms_norm]
302
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # The effective scale is (1 + weight), so zero-init means identity.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def _forward_static_no_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward() without residual."""
        input_dtype = x.dtype
        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        normed = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        return (normed * (1.0 + weight.float())).to(input_dtype)

    @staticmethod
    def _forward_static_with_residual(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward() with residual.

        For float16 inputs the residual sum is done in float32 (presumably to
        avoid fp16 overflow), so the returned residual is float32 in that case.
        """
        input_dtype = x.dtype
        if input_dtype == torch.float16:
            summed = x.float() + residual.float()
        else:
            summed = x + residual

        hidden = summed.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        normed = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        out = (normed * (1.0 + weight.float())).to(input_dtype)
        return out, summed

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        if residual is not None:
            return self._forward_static_with_residual(
                self.weight.data, self.variance_epsilon, x, residual
            )
        return self._forward_static_no_residual(
            self.weight.data, self.variance_epsilon, x
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the native path, torch.compile-ing the helpers on first eager use."""
        # Already inside a compiled graph: just trace the native code.
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        # Compile the static helpers once and store the compiled callables on
        # the instance; forward_native then picks them up through ``self``.
        if not getattr(self, "_is_compiled", False):
            for helper_name in (
                "_forward_static_no_residual",
                "_forward_static_with_residual",
            ):
                setattr(self, helper_name, torch.compile(getattr(self, helper_name)))
            self._is_compiled = True
        return self.forward_native(x, residual)
394
395


396
# --8<-- [start:rms_norm_gated]
397
398
399
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    ``forward_native`` is a plain PyTorch implementation and ``forward_cuda``
    delegates to the FLA ``rmsnorm_fn`` kernel. Both support:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(
            torch.empty(hidden_size, device=device, dtype=dtype)
        )
        # No bias: registered as None so ``self.bias`` exists and can be
        # forwarded to the fused kernel.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to all ones."""
        nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Computation runs in float32 and the result is cast back to x's dtype.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        input_dtype = x.dtype
        x32 = x.float()
        w32 = self.weight.float()
        gate = None if z is None else F.silu(z.float())

        # Gate the input before normalizing.
        if gate is not None and not self.norm_before_gate:
            x32 = x32 * gate

        if self.group_size is None:
            # Single group spanning the whole last dimension.
            mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
            out = x32 * torch.rsqrt(mean_sq + self.eps) * w32
        else:
            from einops import rearrange

            # Normalize each contiguous group of ``group_size`` channels.
            grouped = rearrange(x32, "... (g d) -> ... g d", d=self.group_size)
            mean_sq = grouped.pow(2).mean(dim=-1, keepdim=True)
            normed = grouped * torch.rsqrt(mean_sq + self.eps)
            out = rearrange(normed, "... g d -> ... (g d)") * w32

        # Gate the normalized output.
        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(input_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the FLA fused ``rmsnorm_fn`` kernel."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518


class LayerNorm(nn.Module):
    """
    Layer Normalization computed in float32.

    The affine parameters are created in float32. The input is upcast to
    float32 for the computation and the result is cast back to the input's
    dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension (which must equal ``dim``)."""
        # The module may have been cast (e.g. ``.to(torch.bfloat16)``) after
        # construction. Upcast the affine params as well so they match the
        # float32 input given to F.layer_norm; ``.float()`` is a no-op for
        # tensors that are already float32.
        return F.layer_norm(
            x.float(), (self.dim,), self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)