activation.py 14.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
6
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.custom_op import CustomOp
14
from vllm.model_executor.utils import set_weight_attrs
15
from vllm.platforms import current_platform
16
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
17
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19


20
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        """Store the threshold and select the platform implementation.

        Args:
            threshold: Cut-off passed to ``F.threshold`` in the native path
                and to the fused kernel in the CUDA path.
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; route calls to the
            # PyTorch-native implementation instead.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation.

        Splits the last dimension in half, replaces entries of the first
        half that are not greater than ``self.threshold`` with 0.0, and
        multiplies the result element-wise with the second half.
        """
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in ``__init__``.

        An output tensor with the last dimension halved is allocated with
        the dtype and device of ``x`` and handed to ``self.op`` together
        with the input and the threshold; that tensor is then returned.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
54
55


56
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the fused kernels for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel, chosen in forward_cuda when
            # envs.VLLM_USE_OPT_OP is truthy.
            # NOTE(review): this lookup also runs on CPU -- confirm the
            # `_opt` op is registered in CPU builds.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run one of the fused kernels on a freshly allocated output.

        ``envs.VLLM_USE_OPT_OP`` selects between ``self.op_opt`` and
        ``self.op``; both are called as ``op(out, x)``.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel bound in ``__init__`` on a new output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation with plain tensor ops on 2-D views.

        The input is viewed as (-1, 2 * d), the first half is multiplied
        by its own sigmoid (i.e. SiLU), the result is multiplied by the
        second half, and the leading dimensions of ``x`` are restored.
        """
        d = x.shape[-1] // 2
        x_reshaped = x.view(-1, x.shape[-1])
        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
        result = s * x_reshaped[:, d:]
        return result.view(*x.shape[:-1], d)

105

106
107
108
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the kernel (or native fallback) for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            # NOTE(review): this binds the silu_and_mul kernel, whose
            # operand order differs from this class's formula. No
            # forward_xpu is defined in this class (see TODO below), so
            # this class never calls it itself -- confirm before wiring
            # it into an XPU path.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; use the native path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in ``__init__`` on a new output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

142

Robert Shaw's avatar
Robert Shaw committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparsified GeGLU-style activation, used in Gemma3n.

    In terms of the surrounding MLP this corresponds to:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Here the first half of the last dimension plays the role of
    ``gate_proj`` and the second half the role of ``up_proj``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the settings and precompute the cutoff multiplier.

        Args:
            activation_sparsity: Probability fed to the standard normal
                inverse CDF to obtain ``std_multiplier``. Must not be 0.0.
            approximate: GELU mode, either "none" or "tanh".

        Raises:
            ValueError: On an unknown ``approximate`` mode or when
                ``activation_sparsity`` equals 0.0.
        """
        super().__init__()

        # GELU settings.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity settings.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        standard_normal = torch.distributions.normal.Normal(0, 1)
        quantile = torch.tensor(activation_sparsity, dtype=torch.float32)
        self.std_multiplier = standard_normal.icdf(quantile)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Subtract a Gaussian-percentile cutoff and clamp at zero.

        The cutoff is ``mean + std * std_multiplier`` taken over the last
        dimension (population std, i.e. ``unbiased=False``).
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = x.mean(dim=-1, keepdim=True)
        row_std = x.std(dim=-1, keepdim=True, unbiased=False)
        threshold = row_mean + row_std * self.std_multiplier
        return F.relu(x - threshold)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        sparse_gate = self._gaussian_topk(gate)
        return F.gelu(sparse_gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native path."""
        return self.forward_native(x)


194
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching kernels.

        Args:
            approximate: GELU mode, either "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is any other value.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # `op_opt` is the alternative kernel chosen in forward_cuda
            # when envs.VLLM_USE_OPT_OP is truthy.
            # NOTE(review): these lookups also run on CPU -- confirm the
            # `_opt` ops are registered in CPU builds.
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run one of the fused kernels on a freshly allocated output.

        ``envs.VLLM_USE_OPT_OP`` selects between ``self.op_opt`` and
        ``self.op``; both are called as ``op(out, x)``.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel bound in ``__init__`` on a new output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the ``approximate`` setting for the module repr."""
        return f'approximate={repr(self.approximate)}'

249

250
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))),
    element-wise; the output has the same shape as the input.
    """

    def __init__(self):
        """Bind the kernel for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX op called on ``x`` directly."""
        return self.op(x)
274

275

276
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a precomputed constant.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2))),
    element-wise, where 0.7978845608 approximates sqrt(2 / pi).
    """

    def __init__(self):
        """Bind the kernel for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX op called on ``x`` directly."""
        return self.op(x)
299

300

301
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        """Bind the kernel for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel on a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

329

330
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native path."""
        return self.forward_native(x)


344
345
346
347
348
349
350
351
352
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Wrap ``act_module`` and allocate the per-channel scales.

        Args:
            act_module: Activation applied before the division by scales.
            intermediate_size: Full (unsharded) number of scale entries.
            input_is_parallel: When True, only this rank's share of
                ``intermediate_size`` (divided by the tensor-parallel world
                size) is allocated, and ``weight_loader`` loads the
                matching slice.
            params_dtype: Dtype of the scales; the torch default dtype is
                used when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; values are filled in by weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        With ``input_is_parallel`` set, the slice starting at
        ``tp_rank * param.shape[0]`` of length ``param.shape[0]`` along
        dim 0 is taken from ``loaded_weight`` first.

        Raises:
            AssertionError: If the (possibly sliced) weight's shape differs
                from the parameter's shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape, (
            f"Shape mismatch: parameter {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

385

386
387
388
389
390
391
392
393
394
395
396
397
398
# Maps a lower-case activation name to a zero-argument factory for the
# corresponding module; consulted by get_act_fn.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu":
    lambda: nn.GELU(),
    "gelu_fast":
    lambda: FastGELU(),
    "gelu_new":
    lambda: NewGELU(),
    "gelu_pytorch_tanh":
    lambda: nn.GELU(approximate="tanh"),
    "relu":
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "silu":
    lambda: nn.SiLU(),
    "quick_gelu":
    lambda: QuickGELU(),
})
404
405


406
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The name is lower-cased before the lookup in ``_ACTIVATION_REGISTRY``.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
414
415
416
417
418


# Maps a lower-case activation name to a zero-argument factory for the
# fused activation-and-multiply module; consulted by get_act_and_mul_fn.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})


423
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    The name is lower-cased before the lookup in
    ``_ACTIVATION_AND_MUL_REGISTRY``.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]