activation.py 15.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
6
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.custom_op import CustomOp
14
from vllm.model_executor.utils import set_weight_attrs
15
from vllm.platforms import current_platform
16
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
17
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19


20
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        """Store the threshold and select the platform-specific backend.

        Args:
            threshold: Cut-off handed to ``F.threshold`` (native path) or to
                the fused kernel (CUDA-alike path).
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; route calls to the pure
            # PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation: threshold the gate half, then
        multiply it element-wise with the other half."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # Gate entries that do not exceed the threshold are replaced by 0.0.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused-kernel implementation; the op writes into a preallocated
        output whose last dimension is half of the input's."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
54
55


56
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the fused kernels available on the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel selected by envs.VLLM_USE_OPT_OP.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        When the TBO flag is set (and we are not inside torch.compile) or
        the opt-op flag is set, the call is redirected to ``forward_cuda``.
        """
        # NOTE(review): forward_cuda needs self.op / self.op_opt, which
        # __init__ only binds on CUDA-alike and CPU platforms (op_opt is not
        # bound on XPU) -- confirm these flags are never enabled elsewhere.
        if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
            return self.forward_cuda(x)
        elif envs.VLLM_USE_OPT_OP:
            return self.forward_cuda(x)
        else:
            d = x.shape[-1] // 2
            return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused-kernel implementation writing into a preallocated output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU implementation using the op bound in ``__init__``."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        """Neuron implementation: flatten to 2-D, compute x * sigmoid(x) on
        the gate half, multiply by the other half, restore leading dims."""
        d = x.shape[-1] // 2
        x_reshaped = x.view(-1, x.shape[-1])
        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
        result = s * x_reshaped[:, d:]
        return result.view(*x.shape[:-1], d)

110

111
112
113
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the fused kernel available on the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            # NOTE(review): this binds silu_and_mul, not a mul_and_silu op.
            # No forward_xpu is defined in this class (see TODO below), so
            # the binding looks unused here -- confirm before relying on it.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; route calls to the pure
            # PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused-kernel implementation writing into a preallocated output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

147

Robert Shaw's avatar
Robert Shaw committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """A sparsified GELU-and-multiply activation.

    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the configuration and precompute the cutoff multiplier.

        Args:
            activation_sparsity: Target sparsity; it is passed through the
                standard normal inverse CDF to obtain ``std_multiplier``.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is not a supported mode, or if
                ``activation_sparsity`` is 0.0.
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        target_sparsity_tensor = torch.tensor(activation_sparsity,
                                              dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution.

        Subtracts a per-row cutoff (mean + std * std_multiplier over the
        last dimension) and clamps negative results to zero.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return F.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegates to ``forward_native``; no fused kernel is used here."""
        return self.forward_native(x)


199
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching fused kernels.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is not a supported mode.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # ``op_opt`` is the alternative kernel selected by
            # envs.VLLM_USE_OPT_OP in forward_cuda.
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused-kernel implementation writing into a preallocated output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU implementation using the op bound in ``__init__``."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the configured GELU mode for the module's repr."""
        return f'approximate={repr(self.approximate)}'

254

255
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    element-wise; the output has the same shape as the input.
    """

    def __init__(self):
        """Bind the kernel available on the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Kernel implementation; the op writes into a preallocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU implementation; here the bound op returns its result."""
        return self.op(x)
279

280

281
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a precomputed constant.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
    element-wise; the output has the same shape as the input.
    """

    def __init__(self):
        """Bind the kernel available on the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Kernel implementation; the op writes into a preallocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU implementation; here the bound op returns its result."""
        return self.op(x)
304

305

306
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    The output has the same shape as the input.
    """
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        """Bind the kernel available on the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Kernel implementation; the op writes into a preallocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU implementation; same out-parameter calling convention as
        ``forward_cuda``, using the op bound in ``__init__``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

334

335
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegates to ``forward_native``; no fused kernel is used here."""
        return self.forward_native(x)


349
350
351
352
353
354
355
356
357
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Wrap ``act_module`` and allocate the per-channel scales.

        Args:
            act_module: Activation applied before the division by scales.
            intermediate_size: Full (unsharded) number of scale entries.
            input_is_parallel: If True, this rank holds only
                ``intermediate_size / tp_size`` scale entries.
            params_dtype: Dtype of the scales; falls back to the torch
                default dtype when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; values are filled in by weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        When the input is parallel, the slice along dim 0 that starts at
        ``tp_rank * shard_size`` is taken from the full weight first.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

390

391
392
393
394
395
396
397
398
399
400
401
402
403
# Maps a lower-case activation name to a zero-argument factory that builds
# the corresponding module.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu":
    lambda: nn.GELU(),
    "gelu_fast":
    lambda: FastGELU(),
    "gelu_new":
    lambda: NewGELU(),
    "gelu_pytorch_tanh":
    lambda: nn.GELU(approximate="tanh"),
    "relu":
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "silu":
    lambda: nn.SiLU(),
    "quick_gelu":
    lambda: QuickGELU(),
})
409
410


411
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The lookup is case-insensitive: the name is lower-cased before it is
    checked against ``_ACTIVATION_REGISTRY``.

    Raises:
        ValueError: If the name is not present in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
419
420
421
422
423


# Maps a lower-case activation name to a zero-argument factory that builds
# the fused activation-and-multiply module.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})


428
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (e.g. SiluAndMul) function by name.

    The lookup is case-insensitive: the name is lower-cased before it is
    checked against ``_ACTIVATION_AND_MUL_REGISTRY``.

    Raises:
        ValueError: If the name is not present in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]