layernorm.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.model_executor.custom_op import CustomOp
11
12
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
13
    vllm_is_batch_invariant,
14
)
15
16
17
from vllm.platforms import current_platform


18
19
20
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Return the weight-scaled RMS normalization of ``x`` as a new tensor.

    When batch invariance is enabled the batch-invariant implementation is
    used. Otherwise a buffer shaped like ``x`` is allocated and filled by the
    ``ops.rms_norm`` custom op.
    """
    from vllm import _custom_ops as ops

    if not vllm_is_batch_invariant():
        result = torch.empty_like(x)
        ops.rms_norm(result, x, weight, variance_epsilon)
        return result
    return rms_norm_batch_invariant(x, weight, variance_epsilon)


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x``, RMS-normalize the sum, and return both.

    Returns:
        ``(normalized, x + residual)``.

    On the default path the work is done by ``ops.fused_add_rms_norm``. The
    very same ``x`` / ``residual`` tensor objects are then returned, so the
    op is expected to write its results into them. In batch-invariant mode
    the inputs are not handed to the custom op and newly computed tensors are
    returned instead.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Form the sum once: it is both the normalization input and the
        # returned residual (previously it was computed twice).
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


56
57
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply the ``poly_norm`` custom op to ``x`` and return the result.

    The output buffer comes from ``torch.empty_like(x)`` and is filled by
    ``ops.poly_norm``. ``weight``, ``bias`` and ``variance_epsilon`` are
    forwarded unchanged.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
70
71


72
73
74
75
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Select the RMSNorm implementation to use on ROCm.

    AITER kernels are selected only when ``use_aiter`` is set and ``dtype``
    is float16 or bfloat16. In every other case the module-level
    ``fused_add_rms_norm`` / ``rms_norm`` wrappers are returned.

    Args:
        with_fused_add: Select the residual-add variant.
        dtype: Weight dtype the returned function will be used with.
        use_aiter: Whether AITER RMSNorm kernels may be used.

    Returns:
        The selected callable.
    """
    aiter_supported = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if not aiter_supported:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    if with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    return rocm_aiter_ops.rms_norm
89
90


91
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create an RMSNorm layer.

        Args:
            hidden_size: Expected size of the last dimension of the input.
            eps: Value added to the mean of squares before the rsqrt.
            var_hidden_size: If set and different from ``hidden_size``, the
                mean of squares is taken over only the first
                ``var_hidden_size`` features. The whole input is still scaled.
            has_weight: If False, the weight is a fixed tensor of ones rather
                than an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # An override equal to hidden_size is a no-op, so store None. The
        # kernel paths below fall back to forward_native whenever an override
        # is set.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        # Kept as a plain tensor of ones when has_weight is False. The kernel
        # paths always pass self.weight.data, so they still receive a weight.
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Normalization is computed in float32. The result is cast to
        ``orig_dtype`` before ``weight`` is applied. When ``residual`` is given,
        ``x + residual`` is normalized and ``(normalized, x + residual)`` is
        returned, with the sum cast to ``orig_dtype``.

        Raises:
            ValueError: If the last dimension of ``x`` differs from
                ``hidden_size``, or ``variance_size_override`` exceeds it.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the last (feature) dimension for inputs of any rank.
            # x[:, :, :k] only worked for 3-D tensors.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Use the custom-op kernels. Falls back to native when an override is set."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Use the functions chosen in __init__. Falls back to native with an override."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Use the IPEX ops. Falls back to native when an override is set."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # The op is called for its effect on x / residual, which are
            # then returned as (normalized, summed residual).
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


256
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization variant used for Gemma.

    Compared with RMSNorm in this module it:
        1. scales by (1 + w) rather than by w;
        2. applies the weight in float32 and casts afterwards, i.e.
           (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Zero-initialized; combined with the (1 + w) scaling below.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        With ``residual`` the sum ``x + residual`` is normalized and returned
        alongside the normalized output. For float16 inputs that sum is
        formed, and returned, in float32.
        """
        input_dtype = x.dtype

        if residual is not None:
            if input_dtype == torch.float16:
                summed = x.float() + residual.float()
            else:
                summed = x + residual
            x = summed
            residual = summed

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        out = hidden.to(input_dtype)

        if residual is None:
            return out
        return out, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(
            self.weight.data, self.variance_epsilon, x, residual
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run forward_native, compiling forward_static once outside tracing.

        When already inside torch.compile tracing no nested compile is
        attempted. Otherwise the first call replaces this instance's
        ``forward_static`` with a ``torch.compile``-wrapped version.
        """
        if not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        ):
            compiled = torch.compile(self.forward_static)
            self.forward_static = compiled  # type: ignore
            self._is_compiled = True

        return self.forward_native(x, residual)
322
323


324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    Supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation

    ``forward_native`` is a pure PyTorch implementation. ``forward_cuda``
    delegates to ``rmsnorm_fn`` from the FLA layernorm_guard module.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # The bias slot is registered but left as None; forward_cuda passes it
        # through to rmsnorm_fn and forward_native does not use it.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the weight to ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        The computation runs in float32, as RMSNorm.forward_static in this
        module does, and the result is cast back to ``x.dtype``.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor with the dtype of ``x``

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        orig_dtype = x.dtype
        # Upcast so the mean of squares (and the SiLU gate) is not computed
        # in bf16/fp16 precision.
        x = x.float()
        gate = z.float() if z is not None else None
        weight = self.weight.float()

        # Apply gating before normalization if needed
        if gate is not None and not self.norm_before_gate:
            x = x * F.silu(gate)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            out = x * torch.rsqrt(variance + self.eps)
        else:
            # Group RMS norm: split the last dim into (-1, group_size),
            # normalize each group, then merge the groups back.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            out = (x_group * torch.rsqrt(variance + self.eps)).flatten(-2)
        out = out * weight

        # Apply gating after normalization if needed
        if gate is not None and self.norm_before_gate:
            out = out * F.silu(gate)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the fused ``rmsnorm_fn`` implementation."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


427
428
429
430
431
432
433
434
435
436
437
438
439
class LayerNorm(nn.Module):
    """
    Layer normalization over the last ``dim`` features.

    Affine parameters are float32. The input is normalized in float32 and
    the result is converted back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Always float32, independent of the default dtype.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        normalized = F.layer_norm(
            x.to(torch.float32),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)