"""Custom normalization layers."""
from typing import Optional, Tuple, Union

4
5
6
import torch
import torch.nn as nn

7
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
8
import vllm.envs as envs
9
10


11
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Length of the learned weight vector.
            eps: Constant added to the mean square before the reciprocal
                square root is taken.
        """
        super().__init__()
        # Learned per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor; normalized over its last dimension.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a
            tuple ``(normalized, new_residual)`` where ``new_residual`` is
            ``x + residual`` cast back to the dtype ``x`` arrived in.
        """
        orig_dtype = x.dtype
        # The statistics are computed in float32 regardless of input dtype.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The pre-normalization sum is handed back as the new residual.
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back first, then scale by the weight.
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize through the ops exposed by ``vllm._custom_ops``.

        The ``*_opt`` variant of each op is used when
        ``envs.VLLM_USE_OPT_OP`` is truthy.

        Args:
            x: Input tensor.
            residual: Optional residual tensor.

        Returns:
            A freshly allocated output tensor when ``residual`` is None;
            otherwise the ``(x, residual)`` objects that were passed in,
            returned after the fused op has been called on them.
        """
        # Imported lazily, as in the rest of this class's backend methods.
        from vllm import _custom_ops as ops

        if residual is not None:
            if envs.VLLM_USE_OPT_OP:
                fused_add_norm = ops.fused_add_rms_norm_opt
            else:
                fused_add_norm = ops.fused_add_rms_norm
            # The op's return value is not used; x and residual themselves
            # are returned (presumably updated in place by the kernel --
            # confirm against the op implementation).
            fused_add_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual

        out = torch.empty_like(x)
        if envs.VLLM_USE_OPT_OP:
            norm = ops.rms_norm_opt
        else:
            norm = ops.rms_norm
        # `out` is passed as the first argument and returned afterwards.
        norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize through the ops exposed by ``vllm._ipex_ops``.

        Args:
            x: Input tensor.
            residual: Optional residual tensor.

        Returns:
            The value returned by ``ops.rms_norm`` when ``residual`` is
            None; otherwise the ``(x, residual)`` objects that were passed
            in, returned after the fused op has been called on them.
        """
        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size`` / ``eps`` summary string."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130


class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Length of the learned weight vector.
            eps: Constant added to the mean square before the reciprocal
                square root is taken.
        """
        super().__init__()
        # Initialised to zeros: the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            weight: Scale vector; applied as ``1 + weight``.
            variance_epsilon: Constant added to the mean square.
            x: Input tensor; normalized over its last dimension.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a
            tuple ``(normalized, x + residual)``.
        """
        orig_dtype = x.dtype
        if residual is not None:
            # The sum is taken before the float() cast below and is
            # handed back unchanged as the new residual.
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Delegates to ``forward_static`` with this layer's weight data and
        epsilon.
        """
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``forward_native``, compiling ``forward_static`` on first use.

        Args:
            x: Input tensor.
            residual: Optional residual tensor.

        Returns:
            Whatever ``forward_native`` returns for the same arguments.
        """
        # Already inside a compilation: call the plain implementation
        # rather than wrapping it in torch.compile again.
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        # First eager call on this instance: rebind forward_static to its
        # torch.compile-wrapped form and remember that this was done.
        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)