layernorm.py 15.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
13
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
14
    vllm_is_batch_invariant,
15
)
16
from vllm.platforms import current_platform
17
from vllm import envs
18
import lightop as op
19
20


21
22
23
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMSNorm to ``x`` and return the result in a new tensor.

    Args:
        x: Input tensor. The output is allocated with ``torch.empty_like(x)``,
            so it matches ``x`` in shape, dtype and device.
        weight: Per-feature scale passed straight to the kernel.
        variance_epsilon: Epsilon passed straight to the kernel.

    Returns:
        The normalized tensor.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    # NOTE: a lightop ``op.rmsnorm_forward`` path used to sit here behind a
    # hard-coded ``if False:`` (originally gated on ``envs.VLLM_USE_OPT_OP``).
    # It was unreachable, so only the vLLM custom op is called.
    out = torch.empty_like(x)
    # The custom op writes its result into ``out`` in place.
    ops.rms_norm(out, x, weight, variance_epsilon)
    return out


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``x`` to ``residual`` and RMS-normalize the sum.

    Returns:
        ``(normalized, new_residual)`` where ``new_residual`` is ``x + residual``.
        On the custom-op path the function returns the very ``x`` and
        ``residual`` tensors it handed to ``ops.fused_add_rms_norm``, so that
        kernel is expected to write both results into them in place.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once: it is both the norm input and the new residual.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    # NOTE: a lightop ``op.rn_add_forward_autograd`` path used to sit here
    # behind a hard-coded ``if False:`` (originally gated on
    # ``envs.VLLM_USE_OPT_OP``). It was unreachable, so only the vLLM custom op
    # is called.
    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


77
78
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``poly_norm`` custom op on ``x`` and return its output.

    The output buffer comes from ``torch.empty_like(x)`` (same shape, dtype and
    device as ``x``) and is filled in place by the kernel.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


93
94
95
96
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Choose the RMSNorm callable to use on ROCm.

    AITER kernels are only selected when ``use_aiter`` is set AND ``dtype`` is
    float16 or bfloat16; otherwise the module-level ``fused_add_rms_norm`` /
    ``rms_norm`` functions are returned.

    Args:
        with_fused_add: Pick the variant that also adds a residual.
        dtype: dtype of the layer's weight.
        use_aiter: Whether AITER RMSNorm is enabled.
    """
    aiter_supported = use_aiter and dtype in (torch.float16, torch.bfloat16)
    if aiter_supported:
        return (
            rocm_aiter_ops.rms_norm2d_with_add
            if with_fused_add
            else rocm_aiter_ops.rms_norm
        )

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
110
111


112
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create an RMSNorm layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If set and different from ``hidden_size``, only
                the first ``var_hidden_size`` features of the last dimension
                contribute to the variance (the whole row is still scaled).
                Layers with such an override always use ``forward_native``.
            has_weight: If False, ``self.weight`` is a plain all-ones tensor
                (not an ``nn.Parameter``) and the native path skips the
                multiplication by it.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: dtype the normalized output and the updated residual
                are cast to; the normalization itself runs in float32.
            weight: Optional scale, applied after the cast to ``orig_dtype``.
            residual: If given, ``x + residual`` is normalized and that sum is
                also returned as the new residual.
            variance_size_override: If given, only the first this-many
                features of the last dimension contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size`` or ``hidden_size`` is
                smaller than ``variance_size_override``.
        """
        # ``orig_dtype`` is honoured as passed (it used to be silently
        # overwritten with ``x.dtype``); forward_native passes ``x.dtype``.
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )
            # Slice the last dimension with an ellipsis; the previous
            # ``x[:, :, :k]`` only accepted exactly 3-D inputs.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the vLLM custom-op kernels (native path if variance is overridden)."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the ROCm functions chosen in ``__init__`` (native if overridden)."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Use apex's ``fused_rms_norm_affine`` for the residual-free case.

        The residual case reuses ``self.rocm_norm_func_with_add``, which is only
        set on ROCm; elsewhere it falls back to ``fused_add_rms_norm`` instead
        of raising ``AttributeError``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            add_fn = getattr(self, "rocm_norm_func_with_add", fused_add_rms_norm)
            return add_fn(x, residual, self.weight.data, self.variance_epsilon)

        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        return fused_rms_norm_affine(
            x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the IPEX kernels (native path if variance is overridden)."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # The IPEX op updates ``x`` and ``residual`` in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the RMSNorm class in this module:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # The forward pass scales by (1 + weight), so zero-init is identity.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        When ``residual`` is given, the sum ``x + residual`` is normalized and
        also returned (uncast) as the new residual. For float16 inputs that sum
        is taken in float32, so the returned residual is float32 in that case.
        """
        input_dtype = x.dtype
        if residual is not None:
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        hidden = hidden.to(input_dtype)

        if residual is None:
            return hidden
        return hidden, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        # Looked up on ``self`` so the compiled override from forward_cuda wins.
        static_fn = self.forward_static
        return static_fn(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Lazily torch.compile forward_static once, then run the native path.

        Compilation is skipped while already inside a torch.compile trace.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)


# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias; registered as None so ``self.bias`` exists for forward_cuda.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        The computation runs in float32 (as RMSNorm and GemmaRMSNorm do in
        this module) and the result is cast back to ``x.dtype``.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor, with the dtype of ``x``

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        orig_dtype = x.dtype
        x = x.float()
        gate = F.silu(z.float()) if z is not None else None

        # Apply gating before normalization if needed
        if gate is not None and not self.norm_before_gate:
            x = x * gate

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
        else:
            # Group RMS norm: split the last dim into (-1, group_size) groups,
            # normalize each group, then merge the groups back.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = (x_group * torch.rsqrt(variance + self.eps)).flatten(-2)
        out = x_normed * self.weight.float()

        # Apply gating after normalization if needed
        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the fused ``rmsnorm_fn`` from the FLA layernorm guard."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


class LayerNorm(nn.Module):
    """
    Layer Normalization.

    Weight and bias are created in float32, the normalization is computed in
    float32, and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last ``dim`` features and return ``x``'s dtype."""
        # Upcast the parameters too: if the module was cast (e.g. via
        # ``.to(torch.bfloat16)``), F.layer_norm would otherwise receive a
        # float32 input with lower-precision weight/bias. ``.float()`` is a
        # no-op when they are already float32.
        return F.layer_norm(
            x.float(), (self.dim,), self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)