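# activation.py (19.9 KB)
# NOTE: this text was captured from a repository blame view; bare line numbers
# and "<name>'s avatar / <name> committed" lines are page residue, not code.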
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6
7
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
import torch
import torch.nn as nn
10
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
14
15
16
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
from vllm.model_executor.utils import set_weight_attrs
20
from vllm.platforms import current_platform
21
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
25

26
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.0):
        """Store the threshold and select the platform-specific backend.

        Args:
            threshold: Cut-off passed to the thresholding step; inputs that
                do not exceed it are replaced by 0.0 in the native path.
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; force the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation: threshold the first half, then gate."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, which is expected to fill ``out`` in place."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
60
61


62
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the fused kernel for the current platform, if one is used."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; force the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, which is expected to fill ``out`` in place."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the IPEX op."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

103

104
105
106
107
108
109
110
111
112
113
114
115
116
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the fused kernel on CUDA-like platforms; otherwise go native."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu() or current_platform.is_cpu():
            # XPU previously bound ipex_ops.silu_and_mul here, but that kernel
            # applies SiLU to the *first* half, which is the wrong operand for
            # this op. Until a dedicated XPU kernel exists, use the PyTorch
            # implementation so the result is correct.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, which is expected to fill ``out`` in place."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


Robert Shaw's avatar
Robert Shaw committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the configuration and precompute the Gaussian cut-off.

        Args:
            activation_sparsity: Target fraction of gate activations to zero
                out; used as the probability fed to the standard-normal
                inverse CDF.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is unknown, if
                ``activation_sparsity`` is outside [0.0, 1.0] (the inverse
                CDF would be NaN), or if it is exactly 0.0.
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        # A probability outside [0, 1] (or NaN) makes icdf return NaN, which
        # would silently poison every activation; fail fast instead.
        if not 0.0 <= activation_sparsity <= 1.0:
            raise ValueError(
                f"activation_sparsity must be within [0.0, 1.0], "
                f"got {activation_sparsity}."
            )
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # Number of standard deviations above the mean at which to cut.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        # Values at or below the cut-off become 0; the rest are shifted down.
        return F.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No fused kernel is used; delegate to the PyTorch implementation."""
        return self.forward_native(x)


191
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching fused kernel.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is not one of the supported modes.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # CUDA-like and CPU platforms both take their kernel from torch.ops._C.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, which is expected to fill ``out`` in place."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the IPEX op."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the configured approximation mode for the module repr."""
        return f"approximate={repr(self.approximate)}"
241

242

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant operating on interleaved gate/up channels.

    In the native path the gate is taken from the even indices of the last
    dimension and the up projection from the odd indices. The gate is clamped
    from above by ``limit``, the up projection to [-limit, limit], and the
    result is (up + 1) * gate * sigmoid(alpha * gate).
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        """Store the sigmoid scale ``alpha`` and the clamp bound ``limit``."""
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""

        # Gate and up channels are interleaved along the last dimension.
        gate, up = x[..., ::2], x[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, which is expected to fill ``out`` in place."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Return the configured alpha and limit for the module repr."""
        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"


272
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))."""

    def __init__(self):
        """Bind the kernel for the current platform, if one is used."""
        super().__init__()
        # CUDA-like and CPU platforms both take their kernel from torch.ops._C.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which is expected to fill ``out`` in place."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the IPEX op's result directly (no preallocated output)."""
        return self.op(x)
295

296

297
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU using the precomputed constant 0.7978845608."""

    def __init__(self):
        """Bind the kernel for the current platform, if one is used."""
        super().__init__()
        # CUDA-like and CPU platforms both take their kernel from torch.ops._C.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which is expected to fill ``out`` in place."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the IPEX op's result directly (no preallocated output)."""
        return self.op(x)
319

320

321
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        """Bind the kernel for the current platform, if one is used."""
        super().__init__()
        # CUDA-like and CPU platforms both take their kernel from torch.ops._C.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which is expected to fill ``out`` in place."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the IPEX op."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

350

351
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch implementation (no fused kernel yet)."""
        # TODO : implement cuda kernels
        return self.forward_native(x)


366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        """Create the learnable parameters and probe for the CUDA extension.

        Args:
            alpha_p_init: Initial value of the positive-branch coefficient
                (after softplus).
            alpha_n_init: Initial value of the negative-branch coefficient
                (after softplus and adding ``beta``).
            beta: Linear coefficient shared by both branches.
            eps: Upper bound applied to ``x`` before ``expm1`` on the
                negative branch.
            dtype: Dtype of the parameters and buffers.
            with_vector_loads: Flag forwarded unchanged to the CUDA op.
        """
        super().__init__()
        # Parameters are stored as log(exp(v) - 1), i.e. the inverse of
        # softplus, because _xielu_python applies softplus when reading them.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        # Best-effort probe: any failure here (missing package, load error)
        # leaves _xielu_cuda_obj as None and the Python version is used.
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate xIELU with plain PyTorch ops."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Undo the temporary reshape so callers see the input's shape.
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA op for CUDA inputs outside of compilation, else Python."""
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Delegate to forward_native, which performs the backend selection."""
        return self.forward_native(input)

477

478
479
480
481
482
483
484
485
486
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Wrap ``act_module`` and allocate the per-channel scales.

        Args:
            act_module: Activation applied before the division by the scales.
            intermediate_size: Full (unsharded) number of scale entries.
            input_is_parallel: If True, only this tensor-parallel rank's
                share of the scales is allocated and loaded.
            params_dtype: Dtype of the scales; defaults to the current torch
                default dtype.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Left uninitialized on purpose; values arrive through weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation and divide by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``."""
        param_data = param.data
        if self.input_is_parallel:
            # Select the contiguous slice along dim 0 owned by this rank.
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

519

520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
# Registry of plain (non-gated) activations, keyed by lowercase name. Every
# value is a zero-argument factory that builds a fresh module.
# NOTE(review): presumably LazyDict defers calling the factory until the key is
# looked up — confirm against vllm.utils.LazyDict.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
535
536


537
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The lookup is case-insensitive. Names that start with
    ``torch.nn.modules.`` are reduced to their last dotted component;
    ``identity`` is handled directly and returns ``nn.Identity()``.

    Args:
        act_fn_name: Activation name or a ``torch.nn.modules.*`` path.

    Returns:
        The activation module registered under the resolved name.

    Raises:
        ValueError: If the resolved name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()

    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            # Identity is not in the registry, so short-circuit here.
            return nn.Identity()
        act_fn_name = activation_name

    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
551
552


553
554
555
556
557
558
559
560
# Registry of gated "activation-and-mul" modules, keyed by lowercase name.
# "gelu" and "geglu" both map to GeluAndMul with its default arguments.
# NOTE(review): the "swigluoai" factory accepts *args/**kwargs, but whether
# LazyDict ever passes arguments to a factory is not visible here — confirm.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
561
562


563
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    The lookup is case-insensitive.

    Args:
        act_fn_name: Name of the gated activation.

    Returns:
        The module registered under the lowercased name.

    Raises:
        ValueError: If the name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]