"requirements/cpu.txt" did not exist on "cfaf49a1673c872d2a06560346efb13695f82f35"
activation.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
6
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.custom_op import CustomOp
14
from vllm.model_executor.utils import set_weight_attrs
15
from vllm.platforms import current_platform
16
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18


19
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated activation built on FATReLU (used by openbmb/MiniCPM-S-1B-sft).

    The last dimension of ``x`` is split into two halves of width
    ``d = x.shape[-1] // 2`` and the result is
    ``FATReLU(x[..., :d]) * x[..., d:]``, where FATReLU keeps values strictly
    greater than ``threshold`` and replaces everything else with 0.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel from the compiled `_C` extension.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No kernel is bound for CPU here, so CPU calls are routed to the
            # pure-PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel into a newly allocated half-width output."""
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        self.op(out, x, self.threshold)
        return out
53
54


55
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gated activation.

    With ``d = x.shape[-1] // 2`` this returns
    ``silu(x[..., :d]) * x[..., d:]``: SiLU is applied to the first half of
    the last dimension, which then gates the second half.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            # IPEX ops are only imported when running on XPU.
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No kernel is bound for CPU here; use the PyTorch version.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def _run_fused_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        self.op(out, x)
        return out

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Operates on a 2-D view and spells SiLU out as g * sigmoid(g).
        half = x.shape[-1] // 2
        rows = x.view(-1, x.shape[-1])
        gate = rows[:, :half]
        activated = gate * F.sigmoid(gate)
        return (activated * rows[:, half:]).view(*x.shape[:-1], half)

102

103
104
105
106
107
108
109
110
111
112
113
114
115
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Gated activation with SiLU applied to the second half.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2,
    i.e. the roles of the two halves are swapped relative to ``SiluAndMul``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu() or current_platform.is_cpu():
            # No fused kernel with these semantics is bound for XPU/CPU, so
            # both use the PyTorch implementation. IPEX's `silu_and_mul`
            # computes silu(x[:d]) * x[d:] (the mirror image of this op) and
            # must NOT be bound as `self.op` here.
            # TODO: add a fused XPU kernel and a matching forward_xpu.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused `_C` kernel into a newly allocated output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out


Robert Shaw's avatar
Robert Shaw committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """GELU-and-mul with Gaussian top-k sparsification of the gate (Gemma3n).

    With ``d = x.shape[-1] // 2``, ``x[..., :d]`` is the gate and
    ``x[..., d:]`` the up projection. The gate is sparsified by
    ``_gaussian_topk`` (``relu(gate - cutoff)`` with a per-row cutoff), passed
    through GELU and multiplied by the up half. In model terms:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Args:
        activation_sparsity: target fraction of gate values to drop; must be
            strictly between 0 and 1 (use ``GeluAndMul`` for 0).
        approximate: GELU approximation mode, ``"none"`` or ``"tanh"``.

    Raises:
        ValueError: if ``approximate`` is unknown or ``activation_sparsity``
            is not in the open interval (0, 1).

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        if not 0.0 < activation_sparsity < 1.0:
            # The normal icdf is only finite on (0, 1); other values would
            # make the cutoff +/-inf or NaN and silently zero out or poison
            # every output instead of failing loudly.
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}.")
        target_sparsity_tensor = torch.tensor(activation_sparsity,
                                              dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # z-score below which `activation_sparsity` of a standard normal lies.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel; reuse the PyTorch implementation.
        return self.forward_native(x)


191
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU gated activation.

    With ``d = x.shape[-1] // 2`` this returns
    ``GELU(x[..., :d]) * x[..., d:]``, using exact GELU when
    ``approximate == "none"`` and the tanh approximation when it is
    ``"tanh"``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # Only "none" and "tanh" can reach this point.
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_tanh_and_mul
                       if use_tanh else torch.ops._C.gelu_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if use_tanh else ipex_ops.gelu_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.gelu(x[..., :half], approximate=self.approximate)
        return activated * x[..., half:]

    def _run_fused_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        self.op(out, x)
        return out

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

241

242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    """Clamped SwiGLU variant (GPT-OSS style).

    Unlike the other *AndMul ops, gate and up values are interleaved along the
    last dimension: even positions hold the gate, odd positions the up
    projection. With ``g = clamp(gate, max=limit)`` the result is
    ``(clamp(up, -limit, limit) + 1) * g * sigmoid(alpha * g)``.

    Args:
        alpha: scale applied to the gate inside the sigmoid.
        limit: clamp bound (upper bound for the gate, symmetric for up).
    """

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = x[..., 0::2].clamp(max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        return (up + 1) * (gate * torch.sigmoid(gate * self.alpha))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused `_C` kernel into a newly allocated output."""
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


271
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU, applied elementwise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``;
    the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        inner = sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The `_C` kernel fills a caller-provided output buffer.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX variant returns its result rather than filling a buffer.
        return self.op(x)
295

296

297
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU, applied elementwise.

    Computes ``0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x*x)))``
    where 0.7978845608 is sqrt(2/pi); the output matches the input's shape.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The `_C` kernel fills a caller-provided output buffer.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX variant returns its result rather than filling a buffer.
        return self.op(x)
320

321

322
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``, elementwise.

    The output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def _run_fused_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate an output like ``x`` and let ``self.op`` fill it."""
        # Both the `_C` and the IPEX kernel are called as op(out, x).
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_fused_kernel(x)

350

351
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, ``relu(x) ** 2``, applied elementwise.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return F.relu(x).square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement a dedicated CUDA kernel; until then reuse the
        # PyTorch implementation.
        return self.forward_native(x)


366
367
368
369
370
371
372
373
374
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    Wraps ``act_module`` and divides its output elementwise by the learned
    ``scales`` vector. This is used for some quantization methods like AWQ.

    Args:
        act_module: activation applied before scaling.
        intermediate_size: full size of the scaled (last) dimension.
        input_is_parallel: if True, ``scales`` holds only this tensor-parallel
            rank's slice of size ``divide(intermediate_size, tp_size)`` and
            ``weight_loader`` narrows checkpoint tensors to that slice.
        params_dtype: dtype of ``scales``; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; values are expected to be filled in by
        # `weight_loader`, which is attached to the parameter below.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor (or this rank's shard of it) into ``param``.

        Raises:
            ValueError: if the (narrowed) checkpoint tensor's shape differs
                from ``param``'s shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and `copy_` would then silently broadcast a
        # mismatched (e.g. size-1) tensor into the parameter.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                "ScaledActivation.weight_loader: expected a weight of shape "
                f"{tuple(param_data.shape)}, got "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)

407

408
409
410
411
412
413
414
415
416
417
418
419
420
# Plain (non-gated) activations, keyed by the lower-case name accepted by
# `get_act_fn`. Each value is a zero-argument factory building the module
# (presumably invoked by LazyDict on first lookup).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
426
427


428
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Return the activation module registered under ``act_fn_name``.

    The lookup is case-insensitive.

    Raises:
        ValueError: if no activation is registered under that name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
436
437
438


# Gated "activation-and-mul" ops, keyed by the lower-case name accepted by
# `get_act_and_mul_fn`. "gelu" and "geglu" are aliases for the same
# default-constructed GeluAndMul; "swigluoai" forwards any factory arguments
# to SwigluOAIAndMul.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
    "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
})


450
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Return the activation-and-mul module (e.g. SiluAndMul) for a name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if no activation-and-mul op is registered under that name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")