# layernorm.py
"""Custom normalization layers."""
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm._C import ops


class LayerNorm(nn.LayerNorm):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """normalization."""
        if residual is not None:
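            # Add the residual first; the pre-norm sum is returned with the
            # output so the caller can reuse it as the next skip connection.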
            x = x + residual
            residual = x
        x = super().forward(x)
        if residual is None:
            return x
        else:
            return x, residual


class RMSNorm(nn.Module):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
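
    Example (a minimal doctest-style sketch of the PyTorch-native path; the
    fused kernels used by forward() require the compiled ``vllm._C``
    extension)::

        >>> norm = RMSNorm(4)
        >>> x = torch.randn(2, 4)
        >>> y = norm._forward(x)
        >>> ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
        >>> torch.allclose(y, ref, atol=1e-5)
        True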
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def _forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
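            # Fold the residual into ``x`` and keep the pre-norm sum to return.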
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
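            # Fused in-place CUDA kernel: overwrites ``x`` with the normalized
            # output and ``residual`` with the pre-normalization sum.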
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
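        # No residual: the out-of-place kernel writes the result into ``out``.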
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out
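

# A minimal, illustrative sanity check (not part of the original module).
# Importing this file still needs the compiled ``vllm._C`` extension, but the
# check itself runs on CPU through the PyTorch-native ``_forward`` path.
if __name__ == "__main__":
    torch.manual_seed(0)
    norm = RMSNorm(hidden_size=8)
    x = torch.randn(2, 8)
    res = torch.randn(2, 8)

    # The residual path should return the normalized sum and the sum itself.
    out, new_res = norm._forward(x.clone(), res.clone())
    summed = x + res
    ref = summed * torch.rsqrt(
        summed.pow(2).mean(dim=-1, keepdim=True) + norm.variance_epsilon)
    assert torch.allclose(out, ref, atol=1e-5)
    assert torch.allclose(new_res, summed)
    print("RMSNorm._forward matches the reference RMSNorm formula.")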