layernorm.py 8.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4
from typing import Optional, Union
5

6
7
8
import torch
import torch.nn as nn

9
import vllm.envs as envs
10
from vllm.model_executor.custom_op import CustomOp
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether the ROCm AITER RMSNorm kernels should be used.

    All three conditions must hold, and they are checked in this order:
    the current platform is ROCm, ``VLLM_ROCM_USE_AITER_RMSNORM`` is set,
    and ``VLLM_ROCM_USE_AITER`` is set.
    """
    if not current_platform.is_rocm():
        return False
    return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Run vLLM's custom RMSNorm kernel and return a new output tensor.

    The kernel writes into a freshly allocated buffer shaped like ``x``.
    ``x`` itself is only passed as the input argument.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply vLLM's fused residual-add + RMSNorm custom kernel.

    No output buffers are allocated here. The wrapper returns the very same
    ``x`` and ``residual`` objects it was given, so the kernel presumably
    updates both in place. NOTE(review): confirm the exact in-place
    contract against the kernel.

    Args:
        x: Input activations.
        residual: Residual tensor combined with ``x`` by the kernel.
        weight: RMSNorm scale.
        variance_epsilon: Epsilon added to the variance.

    Returns:
        The ``(x, residual)`` tensor objects that were passed in.
    """
    # Local import, presumably to defer loading the compiled extension
    # until a kernel is actually needed.
    from vllm import _custom_ops as ops
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """RMSNorm implemented with the ROCm AITER library.

    Handling depends on the rank of ``x``:

    - More than two dimensions: ``x`` is flattened to
      ``(-1, x.shape[-1])`` before calling ``aiter.rms_norm``, and the
      result is reshaped back to the original shape.
    - Lower-rank inputs: passed through unchanged.

    Presumably the AITER kernel expects 2-D input. NOTE(review): confirm.

    Args:
        x: Input activations; normalization runs over the last dimension.
        weight: RMSNorm scale.
        variance_epsilon: Epsilon added to the variance.

    Returns:
        The normalized tensor, with the same shape as ``x``.
    """
    import aiter as rocm_aiter

    if x.dim() > 2:
        x_original_shape = x.shape
        x = x.reshape(-1, x_original_shape[-1])
        x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
        return x.reshape(x_original_shape)

    return rocm_aiter.rms_norm(x, weight, variance_epsilon)


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm implemented with ROCm AITER.

    Fresh buffers shaped like ``x`` and ``residual`` are allocated. They are
    handed to ``aiter.rmsnorm2d_fwd_with_add`` as the normalized output and
    the residual output respectively, so new tensors are returned.

    Args:
        x: Input activations.
        residual: Residual input tensor.
        weight: RMSNorm scale.
        variance_epsilon: Epsilon added to the variance.

    Returns:
        ``(output, residual_out)``: the normalized result and the updated
        residual written by the AITER kernel.
    """
    import aiter as rocm_aiter

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Pick the RMSNorm implementation used by the CUDA/ROCm forward path.

    The AITER variants are chosen when ``is_rocm_aiter_rmsnorm_enabled()``
    is true. Otherwise the vLLM custom-op wrappers are used. ``add_residual``
    selects between the fused residual-add form and the plain form.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled()
    if add_residual:
        return (rocm_aiter_fused_add_rms_norm
                if use_aiter else fused_add_rms_norm)
    return rocm_aiter_rms_norm if use_aiter else rms_norm


88
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Create the normalization layer.

        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Epsilon added to the variance before the reciprocal sqrt.
            var_hidden_size: When given and different from ``hidden_size``,
                the variance is computed over only the first
                ``var_hidden_size`` features of the last dimension. This
                forces the PyTorch-native path in ``forward_cuda`` and
                ``forward_xpu``.
            has_weight: If True, the scale is an ``nn.Parameter``.
                Otherwise it stays a plain all-ones tensor, and
                ``forward_native`` skips the multiplication.
            dtype: dtype of the scale tensor; ``None`` means torch's default.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # A var_hidden_size equal to hidden_size is the ordinary case, so it
        # is normalized to "no override".
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        # torch.ones treats dtype=None as "use the default dtype", matching
        # the former explicit two-branch construction.
        self.weight = torch.ones(hidden_size, dtype=dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The computation runs in float32:

        - If ``residual`` is given, it is added to ``x`` first. The sum,
          cast back to the input dtype, becomes the returned residual.
        - The normalized result is cast back to the input dtype before the
          scale is applied.

        Raises:
            ValueError: If the last dimension of ``x`` is not
                ``hidden_size``, or is smaller than the variance-size
                override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice only the last dimension so inputs of any rank work; the
            # former x[:, :, :k] indexing required exactly 3-D input.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA/ROCm path.

        Falls back to ``forward_native`` when a variance-size override is
        set. Otherwise it calls the function chosen by
        ``dispatch_cuda_rmsnorm_func``: the vLLM custom op, or the AITER
        kernel when enabled on ROCm.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        else:
            return norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """XPU path using the IPEX ops.

        Falls back to ``forward_native`` when a variance-size override is
        set. In the residual case, the same ``x``/``residual`` objects passed
        to ``ops.fused_add_rms_norm`` are returned.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Summarize hidden size and epsilon for the module's repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
201
202


203
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer.

        The weight starts at zeros, so the effective scale ``1 + w`` is
        initially one.

        Args:
            hidden_size: Length of the learnable weight vector.
            eps: Epsilon added to the variance before the reciprocal sqrt.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        For float16 inputs, the residual is promoted to float32 before the
        add. The sum (and therefore the returned residual) is float32 in
        that case. Presumably this avoids fp16 overflow in the residual
        stream. NOTE(review): confirm intent. For other dtypes, the residual
        keeps the input dtype.
        """
        orig_dtype = x.dtype
        if residual is not None:
            if orig_dtype == torch.float16:
                x = x + residual.float()
            else:
                x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: run forward_static through torch.compile.

        Inside an outer torch.compile trace, it calls ``forward_native``
        directly. Otherwise, on the first call it wraps ``forward_static``
        with ``torch.compile``. The wrapped function is stored as an
        instance attribute, which shadows the staticmethod for this
        instance. Later calls reuse it via ``forward_native``.
        """
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)