# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict

# Module-level logger shared by the activation classes below.
logger = init_logger(__name__)
22

23
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        """Store the threshold and bind the platform-specific kernel.

        Args:
            threshold: Cut-off passed to the thresholding step; values of the
                first half that do not exceed it are replaced by 0.
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; use the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation: threshold the first half, then gate.

        Splits the last dimension in two halves, applies
        ``F.threshold(..., self.threshold, 0.0)`` to the first half and
        multiplies it element-wise with the second half.
        """
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bound ``fatrelu_and_mul`` op into a freshly allocated output.

        The output has the same leading dimensions, dtype and device as ``x``
        with the last dimension halved.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The op receives `out` as its first argument and `out` is returned,
        # so the kernel is expected to fill it in place.
        self.op(out, x, self.threshold)
        return out
57
58


59
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the platform-specific kernel(s) used by the forward methods."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel, selected at call time by
            # envs.VLLM_USE_OPT_OP (see forward_cuda).
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; use the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel into a freshly allocated half-width output.

        Uses ``self.op_opt`` when ``envs.VLLM_USE_OPT_OP`` is truthy and
        ``self.op`` otherwise.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op bound in ``__init__`` into a half-width output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

103

104
105
106
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the platform-specific kernel used by the forward methods."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            # NOTE(review): this binds `silu_and_mul`, not a `mul_and_silu`
            # variant. This class defines no forward_xpu, so nothing in this
            # class calls `self.op` on XPU - confirm before adding one.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound on CPU; use the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bound ``mul_and_silu`` op into a half-width output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparsified GELU gate followed by an element-wise multiply.

    Used in Gemma3n, where the MLP computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Here the first half of the last dimension plays the role of the gate and
    the second half the role of the up projection.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the configuration and precompute the sparsity cut-off.

        Args:
            activation_sparsity: Target fraction used as the probability fed
                to the standard-normal inverse CDF. Must not be 0.0.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: On an unknown ``approximate`` mode or a sparsity of 0.0.
        """
        super().__init__()

        # GELU configuration.
        self.approximate = approximate
        supported_modes = ("none", "tanh")
        if approximate not in supported_modes:
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity configuration.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        standard_normal = torch.distributions.normal.Normal(0, 1)
        # Number of standard deviations above the mean at which to cut.
        self.std_multiplier = standard_normal.icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = x.mean(dim=-1, keepdim=True)
        row_std = x.std(dim=-1, keepdim=True, unbiased=False)
        threshold = row_mean + row_std * self.std_multiplier
        # Everything at or below the per-row threshold becomes zero.
        return F.relu(x - threshold)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        sparse_gate = self._gaussian_topk(gate)
        activated = F.gelu(sparse_gate, approximate=self.approximate)
        return activated * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch-native implementation."""
        return self.forward_native(x)


192
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching platform kernels.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is not one of the supported modes.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # `op_opt` is the alternative kernel chosen at call time by
            # envs.VLLM_USE_OPT_OP (see forward_cuda).
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel into a freshly allocated half-width output.

        Uses ``self.op_opt`` when ``envs.VLLM_USE_OPT_OP`` is truthy and
        ``self.op`` otherwise.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op bound in ``__init__`` into a half-width output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the configured approximation mode for the module repr."""
        return f'approximate={repr(self.approximate)}'

247

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant with interleaved gate/up inputs.

    In the native path the gate values are read from the even indices of the
    last dimension and the up values from the odd indices. The gate is clamped
    from above at ``limit``, the up values are clamped to ``[-limit, limit]``,
    and the result is ``(up + 1) * gate * sigmoid(alpha * gate)``.
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        """Store the sigmoid scale ``alpha`` and the clamp bound ``limit``."""
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # Even positions hold the gate, odd positions hold the up projection.
        gate = x[..., ::2].clamp(min=None, max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        swish = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * swish

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused ``swigluoai_and_mul`` op into a half-width output."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half, ),
                          dtype=x.dtype,
                          device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Return the configured ``alpha`` and ``limit`` for the module repr."""
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


277
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    The native path computes
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        """Bind the platform-specific ``gelu_new`` kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bound op into a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX op called directly on ``x``."""
        return self.op(x)
301

302

303
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a hard-coded constant.

    The native path computes
    ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))``.
    """

    def __init__(self):
        """Bind the platform-specific ``gelu_fast`` kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bound op into a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX op called directly on ``x``."""
        return self.op(x)
326

327

328
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        """Bind the platform-specific ``gelu_quick`` kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the bound op into a new tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op bound in ``__init__`` into a new tensor like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

356

357
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch-native implementation."""
        #TODO : implement cuda kernels
        return self.forward_native(x)


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        """Create the learnable parameters and probe for the CUDA extension.

        Args:
            alpha_p_init: Initial value of the positive-branch coefficient.
            alpha_n_init: Initial value of the negative-branch coefficient;
                ``beta`` is subtracted before it is stored.
            beta: Linear-term coefficient, stored as a buffer.
            eps: Upper bound applied to ``x`` inside ``expm1`` on the
                non-positive branch, stored as a buffer.
            dtype: Dtype of the parameters and buffers.
            with_vector_loads: Flag forwarded to the CUDA implementation.
        """
        super().__init__()
        # Both alphas are stored as log(exp(v) - 1), i.e. the inverse of the
        # softplus that _xielu_python applies when reading them back.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
                      1).unsqueeze(0))
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
                1).unsqueeze(0))
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None when the optional extension cannot be loaded, which
        # makes forward_native use the Python implementation.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
                        "this may result in slower performance.")
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Deliberate best-effort: any failure to load the extension
            # falls back to the Python path instead of failing construction.
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Compute xIELU with plain PyTorch ops.

        For ``x > 0`` the result is ``alpha_p * x^2 + beta * x``; otherwise it
        is ``(expm1(min(x, eps)) - x) * alpha_n + beta * x``.
        """
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
            self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, (
            "XIELU CUDA object must not be None")
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Undo the temporary reshape so callers get their original shape.
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Apply xIELU, preferring the CUDA extension when it is usable.

        The CUDA path is taken only when the extension was loaded, ``input``
        is a CUDA tensor and torch._dynamo is not compiling; in every other
        case the Python implementation is used.
        """
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Delegate to forward_native, which selects the implementation."""
        return self.forward_native(input)

480

481
482
483
484
485
486
487
488
489
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Wrap ``act_module`` and allocate the per-channel scale parameter.

        Args:
            act_module: Activation applied before the scaling.
            intermediate_size: Full (unsharded) number of scale entries.
            input_is_parallel: When True, the scales are sized for one
                tensor-parallel shard (``intermediate_size`` divided by the
                tensor-parallel world size).
            params_dtype: Dtype of the scales; defaults to
                ``torch.get_default_dtype()`` when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialised; the values are filled in by weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation and divide the result by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's shard of it) into ``param``.

        When ``input_is_parallel`` is set, the slice of ``loaded_weight``
        along dim 0 belonging to the current tensor-parallel rank is selected
        first. The shapes must then match exactly.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

522

523
524
525
526
527
528
529
530
531
532
533
534
535
# Lower-case activation name -> zero-argument factory for the module.
# Looked up by get_act_fn; wrapped in LazyDict (presumably so a module is
# only constructed when its name is first requested - confirm in vllm.utils).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu":
    lambda: nn.GELU(),
    "gelu_fast":
    lambda: FastGELU(),
    "gelu_new":
    lambda: NewGELU(),
    "gelu_pytorch_tanh":
    lambda: nn.GELU(approximate="tanh"),
    "relu":
    lambda: nn.ReLU(),
    "relu2":
    lambda: ReLUSquaredActivation(),
    "silu":
    lambda: nn.SiLU(),
    "quick_gelu":
    lambda: QuickGELU(),
    "tanh":
    lambda: nn.Tanh(),
    "sigmoid":
    lambda: nn.Sigmoid(),
    "xielu":
    lambda: XIELU(),
})
547
548


549
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The name is matched case-insensitively. A fully qualified name starting
    with ``torch.nn.modules.`` is reduced to its last dotted component, and
    ``identity`` in that form yields ``nn.Identity()`` directly.

    Args:
        act_fn_name: Registry key or ``torch.nn.modules.*`` qualified name.

    Returns:
        The activation module registered under the resolved name.

    Raises:
        ValueError: If the resolved name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()

    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            # Identity has no registry entry; handle it before the lookup.
            return nn.Identity()
        act_fn_name = activation_name

    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
564
565
566


# Lower-case activation name -> factory for the fused activation-and-multiply
# module. Looked up by get_act_and_mul_fn. "gelu" and "geglu" both map to
# GeluAndMul with its default arguments.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
    "geglu":
    lambda: GeluAndMul(),
    "swigluoai":
    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
})


578
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    The name is matched case-insensitively.

    Args:
        act_fn_name: Key into the activation-and-mul registry.

    Returns:
        The module registered under the lower-cased name.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]