# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""

import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
12
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
13
    vllm_is_batch_invariant,
14
)
15
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
16
from vllm.platforms import current_platform
17
from vllm.utils.torch_utils import direct_register_custom_op
18
19
20


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm kernels are switched on.

    Both the RMSNorm-specific switch ``VLLM_ROCM_USE_AITER_RMSNORM`` and the
    global ``VLLM_ROCM_USE_AITER`` switch must be truthy.
    """
    rmsnorm_flag = envs.VLLM_ROCM_USE_AITER_RMSNORM
    if not rmsnorm_flag:
        return rmsnorm_flag
    return envs.VLLM_ROCM_USE_AITER


def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """RMS-normalize ``x`` and scale it by ``weight``.

    In batch-invariant mode this delegates to ``rms_norm_batch_invariant``.
    Otherwise it allocates an output like ``x`` and fills it with the
    ``ops.rms_norm`` custom kernel.
    """
    from vllm import _custom_ops as ops

    if not vllm_is_batch_invariant():
        result = torch.empty_like(x)
        ops.rms_norm(result, x, weight, variance_epsilon)
        return result

    return rms_norm_batch_invariant(x, weight, variance_epsilon)


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x``, RMS-normalize the sum and scale by ``weight``.

    Args:
        x: Input activations.
        residual: Residual stream to add to ``x``.
        weight: Per-channel scale.
        variance_epsilon: Added to the mean square before the reciprocal root.

    Returns:
        ``(normalized, residual_sum)``.
        - Batch-invariant mode: both tensors are computed from a single
          ``x + residual``.
        - Otherwise: ``ops.fused_add_rms_norm`` is called for its effect on
          ``x`` and ``residual``, and those same tensors are returned.
          Presumably the kernel writes its results in place; confirm against
          the kernel.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Form the residual sum once and reuse it for both outputs.
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


def rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the AITER ``rms_norm`` kernel on ``x``.

    Inputs with more than two dimensions are flattened to
    ``(-1, x.shape[-1])`` for the kernel call, and the result is reshaped
    back to the original shape. Inputs with two or fewer dimensions are
    passed through unchanged.
    """
    import aiter as rocm_aiter

    if x.dim() <= 2:
        return rocm_aiter.rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flattened = x.reshape(-1, original_shape[-1])
    normed = rocm_aiter.rms_norm(flattened, weight, variance_epsilon)
    return normed.reshape(original_shape)


def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm via AITER ``rmsnorm2d_fwd_with_add``.

    Fresh buffers shaped like ``x`` (normalized output) and like ``residual``
    (updated residual) are allocated and handed to the kernel, which fills
    them. The input tensors are passed as read-only arguments.

    Returns:
        ``(normalized_output, residual_output)``.
    """
    import aiter as rocm_aiter

    normed = torch.empty_like(x)
    new_residual = torch.empty_like(residual)
    # Argument order: output, input, residual in, residual out, weight, eps.
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normed,
        x,
        residual,
        new_residual,
        weight,
        variance_epsilon,
    )
    return normed, new_residual


def rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Fake implementation registered alongside ``rocm_aiter_rms_norm``.

    Describes only the output: an uninitialized tensor created with
    ``torch.empty_like(x)``. ``weight`` and ``variance_epsilon`` are unused.
    """
    placeholder = torch.empty_like(x)
    return placeholder


def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation registered alongside the fused add + RMSNorm op.

    Returns two uninitialized tensors created with ``torch.empty_like``: one
    like ``x`` (normalized output) and one like ``residual`` (residual
    output). ``weight`` and ``variance_epsilon`` are unused.
    """
    normed_placeholder = torch.empty_like(x)
    residual_placeholder = torch.empty_like(residual)
    return normed_placeholder, residual_placeholder


if current_platform.is_rocm():
    # Register the AITER RMSNorm kernels as custom ops, each paired with a
    # fake implementation. dispatch_rocm_rmsnorm_func below then reaches them
    # as torch.ops.vllm.<op_name>. Registration happens only on ROCm.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        fake_impl=rocm_aiter_rms_norm_fake,
    )

    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
    )


def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Choose the RMSNorm callable to use on ROCm.

    The AITER custom ops are chosen when AITER RMSNorm is enabled and
    ``dtype`` is float16 or bfloat16. Otherwise this falls back to the
    generic ``rms_norm`` / ``fused_add_rms_norm`` helpers in this module.

    Args:
        with_fused_add: Select the variant that also adds a residual.
        dtype: Weight dtype the norm will run with.
    """
    aiter_dtypes = [torch.float16, torch.bfloat16]
    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in aiter_dtypes

    if not use_aiter:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    return (
        torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        if with_fused_add
        else torch.ops.vllm.rocm_aiter_rms_norm
    )


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Create the norm.

        Args:
            hidden_size: Size of the last dimension of the input.
            eps: Added to the mean square before the reciprocal square root.
            var_hidden_size: If given and different from ``hidden_size``,
                only the first ``var_hidden_size`` channels contribute to the
                variance. All channels are still normalized.
            has_weight: If False, ``weight`` is a plain all-ones tensor
                rather than an ``nn.Parameter``.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "compute the variance over the whole hidden dimension".
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            # Pick the ROCm kernels once, based on the weight dtype.
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False, dtype=weight_dtype
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        The residual add and the normalization run in float32. The result is
        cast back to ``orig_dtype`` before ``weight`` is applied.

        Args:
            x: Input whose last dimension must equal ``hidden_size``. Any
                number of leading dimensions is accepted.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            hidden_size: Expected size of the last dimension of ``x``.
            orig_dtype: Dtype the normalized tensor and the returned residual
                are cast to.
            weight: Optional per-channel scale.
            residual: Optional tensor added to ``x`` before normalization.
            variance_size_override: If set, only the first this-many channels
                of the last dimension contribute to the variance.

        Returns:
            The normalized tensor, or ``(normalized, x + residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``,
                or ``hidden_size`` is smaller than ``variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            # Slice the *last* dimension, whatever the input rank. A fixed
            # x[:, :, :n] only works for 3-D inputs and slices the wrong
            # axis otherwise.
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """CUDA path using the module-level ``rms_norm`` helpers.

        A partial-variance norm (``variance_size_override`` set) always uses
        the native implementation instead.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the kernels selected in ``__init__``.

        ``rocm_norm_func`` and ``rocm_norm_func_with_add`` are only set when
        ``current_platform.is_rocm()`` was true at construction time.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path using the IPEX kernels.

        The fused variant is called for its effect on ``x`` and ``residual``,
        which are then returned. Presumably it writes them in place; confirm
        against ``ipex_ops``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Differs from ``RMSNorm`` in two ways:
        1. The scale is ``1 + w`` instead of ``w``. With the zero-initialized
           weight, the module starts as an identity scale.
        2. The scale is applied before casting back: ``(x * w).to(orig_dtype)``
           instead of ``x.to(orig_dtype) * w``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero init pairs with the (1 + w) scale used in forward_static.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            weight: Per-channel offset; the effective scale is ``1 + weight``.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            x: Input tensor; the norm runs over its last dimension.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor in ``x``'s original dtype. When ``residual``
            is given, returns ``(normalized, residual_sum)`` instead. For
            float16 inputs the residual sum is formed (and returned) in
            float32; otherwise it is ``x + residual`` as-is.
        """
        input_dtype = x.dtype

        if residual is not None:
            if input_dtype == torch.float16:
                summed = x.float() + residual.float()
            else:
                summed = x + residual
            x = summed
            residual = summed

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = (hidden * (1.0 + weight.float())).to(input_dtype)

        if residual is None:
            return scaled
        return scaled, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run forward_native, compiling forward_static lazily on first use.

        Outside an active ``torch.compile`` trace, the first call replaces
        ``forward_static`` on this instance with a ``torch.compile``-wrapped
        version. ``forward_native`` reaches it through ``self.forward_static``.
        Inside a trace, nothing is compiled here.
        """
        if not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        ):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)


@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # No bias; registered as None so forward_cuda can pass self.bias.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        All arithmetic is done in float32: ``x``, ``z`` and ``weight`` are
        upcast. The result is cast back to ``x``'s dtype. Previously the
        normalization ran in the input precision, e.g. bf16, and the output
        dtype followed type promotion with ``weight``.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor with ``x``'s dtype

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        orig_dtype = x.dtype
        x = x.float()
        weight = self.weight.float()
        if z is not None:
            z = z.float()

        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
        else:
            # Group RMS norm: split the last dim into (num_groups, group_size).
            # unflatten raises if the hidden size is not divisible by
            # group_size.
            x_group = x.unflatten(-1, (-1, self.group_size))
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = (x_group * torch.rsqrt(variance + self.eps)).flatten(-2)
        out = x_normed * weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out.to(orig_dtype)

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the ``rmsnorm_fn`` kernel with this module's settings."""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


class LayerNorm(nn.Module):
    """Layer normalization over the last ``dim`` channels, computed in float32.

    The affine weight (ones) and bias (zeros) are float32 parameters. The
    input is upcast to float32 for ``F.layer_norm`` and the result is cast
    back with ``type_as(x)``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` and return it in ``x``'s original type."""
        upcast = x.float()
        normalized = F.layer_norm(
            upcast,
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)