# SPDX-License-Identifier: Apache-2.0
"""Custom normalization layers."""

from typing import Optional, Tuple, Union

5
6
7
import torch
import torch.nn as nn

8
from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the last dimension of the input.
            eps: Constant added to the mean square before ``rsqrt`` for
                numerical stability.
            var_hidden_size: If given and different from ``hidden_size``, the
                mean square is computed over only the first
                ``var_hidden_size`` features of the last dimension, while the
                resulting scale is still applied to all features.
            has_weight: If True, the scale is a learnable ``nn.Parameter``
                initialized to ones. Otherwise a plain all-ones tensor is kept
                and the native path skips the multiplication.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # ``None`` means "use the full hidden size"; only then can the fused
        # kernels (which take no override argument) be used.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)
        # NOTE(review): when has_weight is False, ``weight`` is neither a
        # parameter nor a buffer, so ``Module.to()`` will not move it, yet the
        # fused kernels below still read ``self.weight.data``. Confirm it is
        # created on the target device (e.g. under a default-device context).

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input tensor whose last dimension must equal ``hidden_size``.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor in ``x``'s original dtype, or the tuple
            ``(normalized, x + residual)`` when ``residual`` is given.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``hidden_size``,
                or is smaller than the variance-size override.
        """
        orig_dtype = x.dtype
        # Do the residual add and the statistics in fp32.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")
            # Slice only the last dimension so inputs of any rank work; the
            # former ``x[:, :, :k]`` raised IndexError on 2-D inputs.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the input dtype *before* applying the weight.
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """CUDA implementation backed by vLLM's fused RMSNorm kernels.

        Falls back to ``forward_native`` when a variance-size override is
        set, since the kernels are not given that argument.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm import _custom_ops as ops

        if residual is not None:
            # The fused kernel's results are not captured; ``x`` and
            # ``residual`` are returned as-is, i.e. updated in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """HPU implementation using ``HPUFusedRMSNorm`` when available.

        Falls back to ``forward_native`` when a variance-size override is set
        or when the fused op is unavailable.
        """
        # Consistent with the CUDA/XPU paths: the fused op only receives the
        # weight and eps, so it cannot honor a partial-variance override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm_hpu_extension.ops import HPUFusedRMSNorm
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            # In-place residual add; the updated residual is what gets
            # normalized and returned.
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """XPU implementation backed by the IPEX RMSNorm ops.

        Falls back to ``forward_native`` when a variance-size override is
        set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            # As on CUDA, ``x`` and ``residual`` are returned as-is after
            # the fused call, i.e. updated in place.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Summarize hidden size and epsilon for ``repr(module)``."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization variant used by Gemma models.

    It differs from ``RMSNorm`` in this file in two ways:
        1. The scale is ``(1 + w)`` rather than ``w`` (so ``w`` starts at 0).
        2. The weight is applied in fp32 and the result is cast afterwards,
           i.e. ``(x * w).to(orig_dtype)`` instead of
           ``x.to(orig_dtype) * w``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Zero-initialized because the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Kept as a stateless static function so ``forward_cuda`` can wrap it
        with ``torch.compile``.

        Returns the normalized tensor in ``x``'s dtype, or the tuple
        ``(normalized, x + residual)`` when ``residual`` is given.
        """
        input_dtype = x.dtype
        if residual is not None:
            # Unlike RMSNorm, the residual sum stays in the input dtype.
            x = x + residual
            residual = x

        hidden = x.float()
        inv_rms = torch.rsqrt(
            hidden.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = (hidden * inv_rms * (1.0 + weight.float())).to(input_dtype)

        if residual is None:
            return normed
        return normed, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        # ``self.forward_static`` may be the compiled instance attribute set
        # by forward_cuda rather than the plain staticmethod.
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: lazily ``torch.compile`` the static kernel, then run it.

        When already being traced by ``torch.compile``, the eager wrapping is
        skipped and the native path is used directly.
        """
        tracing = torch.compiler.is_compiling()
        if not tracing and not getattr(self, "_is_compiled", False):
            # Shadow the staticmethod with a compiled version on this
            # instance; done once per module.
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)