# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.collection_utils import LazyDict

logger = init_logger(__name__)


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """FATReLU-gated product used by openbmb/MiniCPM-S-1B-sft.

    With ``d = x.shape[-1] // 2`` this returns
    ``FATReLU(x[..., :d]) * x[..., d:]``. FATReLU keeps a gate value only
    when it is strictly greater than ``threshold`` and replaces it with 0.0
    otherwise.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel that writes into a caller-provided output buffer.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound, so CPU calls go to the eager path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native reference implementation."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x, self.threshold)
        return result


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gating: ``silu(x[..., :d]) * x[..., d:]``.

    ``d`` is half of the last dimension (``x.shape[-1] // 2``).

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def _launch_out_kernel(self, x: torch.Tensor) -> torch.Tensor:
        # Both the CUDA and the IPEX kernel follow op(out, x), so a single
        # launcher allocates the half-width output and invokes either one.
        out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
        result = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_out_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_out_kernel(x)


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant with the gate on the second half.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native
        # XPU: deliberately no kernel binding. The IPEX op previously bound
        # here (ipex_ops.silu_and_mul) computes silu(x[:d]) * x[d:], i.e. it
        # gates the wrong half, so it must not back this op. This class
        # defines no forward_xpu; presumably CustomOp's default forward_xpu
        # falls back to forward_native (confirm in custom_op.py).
        # TODO: bind an XPU kernel with mul_and_silu semantics and add
        # forward_xpu.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out


@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """GELU-and-mul with Gaussian top-k sparsification of the gate (Gemma3n).

    Conceptually this implements:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    The op receives the gate and up projections concatenated along the last
    dimension (gate first) and returns ``gelu(topk(gate)) * up``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # The cutoff is the `activation_sparsity` quantile of a standard
        # normal. Normal.icdf is finite only on the open interval (0, 1);
        # anything else (including NaN) would give an inf/NaN cutoff and
        # silently corrupt every output.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}."
            )
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # Number of standard deviations above the per-row mean at which gate
        # values start to survive.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Zero everything below the target percentile, assuming each row of
        the last dimension is Gaussian, and shift survivors down by the cutoff.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel; the eager implementation is used on CUDA too.
        return self.forward_native(x)


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU gating: ``GELU(x[..., :d]) * x[..., d:]``.

    ``d`` is half of the last dimension. ``approximate`` selects exact
    ("none") or tanh-approximated ("tanh") GELU; any other value raises
    ValueError.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (
                torch.ops._C.gelu_tanh_and_mul if use_tanh else torch.ops._C.gelu_and_mul
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_tanh_and_mul if use_tanh else ipex_ops.gelu_and_mul

    def _launch_out_kernel(self, x: torch.Tensor) -> torch.Tensor:
        # CUDA/CPU and IPEX kernels share the op(out, x) convention.
        out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
        result = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_out_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_out_kernel(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"


@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant used by GPT-OSS.

    Gate and up values are interleaved along the last dimension: gate at
    even positions, up at odd positions. With ``g = min(gate, limit)`` and
    ``u = clamp(up, -limit, limit)``, this returns
    ``(u + 1) * g * sigmoid(alpha * g)``.
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # Only the gate's upper side is clamped; up is clamped on both sides.
        gate = x[..., ::2].clamp(max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        return (up + 1) * (gate * torch.sigmoid(gate * self.alpha))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new").

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The _C kernel fills a preallocated buffer of the input's shape.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX kernel allocates and returns its own output.
        return self.op(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU ("gelu_fast").

    Computes ``0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The _C kernel fills a preallocated buffer of the input's shape.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX kernel allocates and returns its own output.
        return self.op(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``.

    Reference:
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX gelu_quick kernel is called with the same op(out, x)
        # convention as the _C kernel.
        out = torch.empty_like(x)
        self.op(out, x)
        return out


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, ``relu(x) ** 2``.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO : implement cuda kernels
        return self.forward_native(x)


@CustomOp.register("xielu")
class XIELU(CustomOp):
    """xIELU activation (https://arxiv.org/abs/2411.13010).

    Elementwise:
        x > 0:   alpha_p * x * x + beta * x
        x <= 0:  alpha_n * (expm1(min(x, eps)) - x) + beta * x
    where alpha_p = softplus(self.alpha_p) and
    alpha_n = beta + softplus(self.alpha_n).

    Both learnable parameters are stored in inverse-softplus form. Right
    after construction, the effective alpha_p / alpha_n therefore equal
    alpha_p_init / alpha_n_init, up to rounding in `dtype`. This needs
    alpha_p_init > 0 and alpha_n_init > beta; otherwise the stored value is
    -inf or NaN.

    If the optional nickjbrowning/XIELU package is installed, its CUDA
    kernel is used for CUDA inputs while torch._dynamo is not compiling.
    Otherwise a single warning is logged and the Python version is used.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        # Inverse softplus: log(exp(a) - 1), so softplus(param) == a.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        # alpha_n is stored relative to beta (see _xielu_python).
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        # Python-float snapshots taken at construction time; the CUDA path
        # passes these instead of calling .item() on the buffers.
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Best-effort optional dependency: any failure keeps the Python
            # fallback (_xielu_cuda_obj stays None).
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Reference xIELU implemented with PyTorch ops."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            # reshape (not view): view raises on non-contiguous inputs,
            # reshape copies only when it has to and is otherwise a view.
            x = x.reshape(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        # Prefer the fused kernel for CUDA tensors, except while dynamo is
        # tracing, where the Python version is used instead.
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
486
487
        intermediate_size: int,
        input_is_parallel: bool = True,
488
        params_dtype: torch.dtype | None = None,
489
490
491
    ):
        super().__init__()
        self.act = act_module
492
        self.input_is_parallel = input_is_parallel
493
494
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
495
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
496
497
498
499
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
500
        self.scales = nn.Parameter(
501
502
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
503
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
504

505
    def forward(self, x: torch.Tensor) -> torch.Tensor:
506
507
        return self.act(x) / self.scales

508
509
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
510
511
512
513
514
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
515
516
517
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

518

519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
# Name -> zero-argument factory for plain (non-gated) activation modules.
# NOTE(review): if LazyDict memoizes what each factory builds, every lookup
# of a name returns the same module instance. That matters for "xielu",
# whose module owns learnable parameters — confirm sharing is intended.
_ACTIVATION_REGISTRY = LazyDict(
    dict(
        gelu=lambda: nn.GELU(),
        gelu_fast=lambda: FastGELU(),
        gelu_new=lambda: NewGELU(),
        gelu_pytorch_tanh=lambda: nn.GELU(approximate="tanh"),
        relu=lambda: nn.ReLU(),
        relu2=lambda: ReLUSquaredActivation(),
        silu=lambda: nn.SiLU(),
        quick_gelu=lambda: QuickGELU(),
        tanh=lambda: nn.Tanh(),
        sigmoid=lambda: nn.Sigmoid(),
        xielu=lambda: XIELU(),
    )
)


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation module by case-insensitive name.

    Fully qualified ``torch.nn.modules.*`` class paths are also accepted and
    resolved by their last dotted component. An ``...Identity`` path returns
    a fresh ``nn.Identity()``.

    Raises:
        ValueError: if the (normalized) name is not registered.
    """
    key = act_fn_name.lower()

    if key.startswith("torch.nn.modules."):
        key = key.rpartition(".")[2]
        if key == "identity":
            return nn.Identity()

    if key not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {key!r} is not supported.")
    return _ACTIVATION_REGISTRY[key]


# Name -> factory for gated "activation-and-mul" modules. "gelu" and "geglu"
# both build a default GeluAndMul.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    dict(
        gelu=lambda: GeluAndMul(),
        silu=lambda: SiluAndMul(),
        geglu=lambda: GeluAndMul(),
        swigluoai=lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    )
)


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if the lowercased name is not registered.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")