layernorm.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.model_executor.custom_op import CustomOp
11
12
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
13
    vllm_is_batch_invariant,
14
)
15
16
17
from vllm.platforms import current_platform


18
19
20
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMS-normalize ``x`` scaled by ``weight`` and return the result.

    When batch-invariant mode is enabled, this delegates to
    ``rms_norm_batch_invariant``. Otherwise it allocates an output shaped
    like ``x`` and lets the ``_custom_ops.rms_norm`` kernel fill it.
    """
    from vllm import _custom_ops as ops

    if not vllm_is_batch_invariant():
        result = torch.empty_like(x)
        ops.rms_norm(result, x, weight, variance_epsilon)
        return result
    return rms_norm_batch_invariant(x, weight, variance_epsilon)


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x``, then RMS-normalize the sum.

    Returns:
        ``(normalized, new_residual)``, where ``new_residual = x + residual``.

    In batch-invariant mode fresh tensors are returned. Otherwise the
    ``_custom_ops.fused_add_rms_norm`` op receives ``x`` and ``residual``
    with no separate output buffer, and those same tensors are returned.
    The op is therefore expected to write its results into them in place.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Compute the sum once: it is both the norm input and the new
        # residual (previously `x + residual` was evaluated twice).
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


56
57
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply the ``_custom_ops.poly_norm`` kernel to ``x``.

    The kernel writes into a freshly allocated tensor created with
    ``torch.empty_like(x)``, and that tensor is returned.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
70
71


72
73
74
75
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm implementation to use on ROCm.

    AITER kernels are chosen only when ``use_aiter`` is set and ``dtype`` is
    float16 or bfloat16. In every other case this returns the
    ``_custom_ops``-backed ``fused_add_rms_norm`` / ``rms_norm`` wrappers
    defined in this module.

    Args:
        with_fused_add: Select the residual-add variant.
        dtype: Weight dtype the calling layer was built with.
        use_aiter: Whether AITER RMSNorm is enabled.

    Returns:
        The chosen callable. ``RMSNorm.forward_hip`` invokes it as
        ``f(x, residual, weight, eps)`` for the fused variant and
        ``f(x, weight, eps)`` otherwise.
    """
    aiter_usable = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if not aiter_usable:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    return rocm_aiter_ops.rms_norm
89
90


91
# --8<-- [start:rms_norm]
92
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    # --8<-- [end:rms_norm]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If given and different from ``hidden_size``,
                only the first ``var_hidden_size`` features contribute to
                the variance. Every backend then uses ``forward_native``.
            has_weight: If False, ``weight`` stays a plain all-ones tensor
                rather than an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "compute the variance over the whole hidden dimension".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        # An all-ones weight is kept even without a learnable weight, so the
        # kernel calls below can always pass self.weight.data.
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Normalization runs in float32. The result is cast back to
        ``orig_dtype`` before ``weight`` is applied.

        Args:
            x: Input; its last dimension must equal ``hidden_size``.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            hidden_size: Expected size of the last dimension.
            orig_dtype: Dtype the normalized values and new residual are
                cast to.
            weight: Optional per-feature scale, applied after the cast.
            residual: If given, added to ``x`` before normalizing. The sum,
                cast to ``orig_dtype``, is returned as the new residual.
            variance_size_override: If given, only the first this-many
                features of the last dimension contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size``, or if
                ``hidden_size`` is smaller than ``variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice only the last dimension with `...` so inputs of any rank
            # work. The former `x[:, :, :k]` raised IndexError for 2-D input.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the ``_custom_ops`` kernels.

        Uses ``forward_native`` instead when a variance size override is
        set, because the kernel wrappers take no such argument.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the implementations selected in ``__init__`` for ROCm.

        Uses ``forward_native`` instead when a variance size override is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the IPEX kernels.

        Uses ``forward_native`` instead when a variance size override is set.
        With a residual, the op is given ``x`` and ``residual`` directly and
        those tensors are returned.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
257
258


259
# --8<-- [start:gemma_rms_norm]
260
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """Gemma flavour of RMS normalization.

    Compared with RMSNorm it differs in two ways:
        1. the scale is (1 + w) rather than w;
        2. the weight is applied in float32 before casting back, i.e.
           (x * w).to(orig_dtype) rather than x.to(orig_dtype) * w.
    """

    # --8<-- [end:gemma_rms_norm]

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Starts at zero because the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch computation shared by every forward path.

        If ``residual`` is given, it is added to ``x`` first, and that sum is
        returned as the new residual. For float16 inputs the addition is done
        in float32, so the returned residual is float32 in that case.
        """
        input_dtype = x.dtype
        if residual is not None:
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        hidden = hidden.to(input_dtype)

        if residual is None:
            return hidden
        return hidden, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward_static`` with this layer's weight and epsilon."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Call ``forward_native`` through a lazily ``torch.compile``d kernel.

        On the first call outside an active compilation, ``forward_static``
        is compiled and stored on the instance, so later ``forward_native``
        calls use the compiled version. Inside an active compilation nothing
        is compiled here.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self.forward_static = torch.compile(self.forward_static)  # type: ignore
            self._is_compiled = True
        return self.forward_native(x, residual)
328
329


330
# --8<-- [start:rms_norm_gated]
331
332
333
334
335
336
337
338
339
340
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    ``forward_native`` is a PyTorch implementation and ``forward_cuda``
    delegates to ``rmsnorm_fn``. Both support:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    # --8<-- [end:rms_norm_gated]

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias is used; registered as None so `self.bias` exists for
        # forward_cuda.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        All arithmetic runs in float32, as in ``RMSNorm.forward_static`` in
        this module, and the result is cast back to the dtype of ``x``.
        Accumulating the mean of squares in fp16/bf16 loses precision.

        Args:
            x: Input tensor; its last dimension is normalized.
            z: Optional gating tensor; ``silu(z)`` multiplies elementwise.

        Returns:
            Normalized (and optionally gated) tensor with the dtype of ``x``.

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        orig_dtype = x.dtype
        x = x.float()
        gate = F.silu(z.float()) if z is not None else None

        # Apply gating before normalization if needed
        if gate is not None and not self.norm_before_gate:
            x = x * gate

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
        else:
            # Group RMS norm: split the last dim into (num_groups, group_size).
            # unflatten raises if the size is not divisible by group_size.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = (x_group * torch.rsqrt(variance + self.eps)).flatten(-2)

        out = x_normed * self.weight.float()

        # Apply gating after normalization if needed
        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to ``rmsnorm_fn`` from ``fla.ops.layernorm_guard``."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


436
437
438
439
440
441
442
443
444
445
446
447
448
class LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    ``weight`` (initialised to ones) and ``bias`` (initialised to zeros) are
    float32 parameters. The input is upcast to float32, normalized over its
    last ``dim`` features, and converted back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.to(x.dtype)