activation.py 21.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
20
from vllm.utils.collection_utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
21

22
23
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """FATReLU-gated multiply, as used by openbmb/MiniCPM-S-1B-sft.

    With ``d = x.shape[-1] // 2`` this computes
    ``FATReLU(x[..., :d]) * x[..., d:]``, where FATReLU keeps an element only
    if it is greater than ``threshold`` and replaces it with 0.0 otherwise.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel that writes into a caller-allocated output buffer.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # On CPU, dispatch straight to the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate = F.threshold(x[..., :half], self.threshold, 0.0)
        return gate * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
59
60


61
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gate: ``silu(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # On CPU, dispatch straight to the PyTorch implementation.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        return F.silu(x[..., :half]) * x[..., half:]

    def _launch_fused(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it in."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_fused(x)

103

104
105
106
107
108
109
110
111
112
113
114
115
116
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Reversed SwiGLU gate.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2,
    i.e. SiLU is applied to the *second* half (the opposite of SiluAndMul).

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            # On CPU, dispatch straight to the PyTorch implementation.
            self._forward_method = self.forward_native
        # NOTE: no fused op is bound on XPU. The IPEX kernel previously bound
        # here, ipex_ops.silu_and_mul, is the SiluAndMul kernel
        # (silu(x[:d]) * x[d:]), which gates the wrong half for this op.
        # This class defines no forward_xpu, so XPU goes through the base
        # CustomOp dispatch instead of a mismatched kernel.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


Robert Shaw's avatar
Robert Shaw committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """GELU-and-mul with Gaussian top-k sparsification of the gate (Gemma3n).

    With ``d = x.shape[-1] // 2``, the first half of ``x`` is the gate and the
    second half is multiplied by the activated gate. Before GELU, the gate is
    sparsified by subtracting a per-row cutoff
    ``mean + std * Normal(0, 1).icdf(activation_sparsity)`` and applying ReLU.
    In the reference model this corresponds to::

        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)

    Args:
        activation_sparsity: quantile of the standard normal used for the
            cutoff, i.e. the fraction of gate entries zeroed if the gate is
            roughly Gaussian. Must lie strictly between 0 and 1.
        approximate: GELU approximation mode, ``"none"`` or ``"tanh"``.

    Raises:
        ValueError: for an unknown ``approximate`` mode, for
            ``activation_sparsity == 0.0`` (use GeluAndMul instead), or for
            ``activation_sparsity`` outside (0, 1).
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # icdf is only finite on (0, 1); outside it the cutoff becomes inf/NaN
        # and silently zeroes (or poisons) every activation. The negated
        # comparison also rejects NaN.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                f"activation_sparsity must be in (0, 1), got {activation_sparsity}."
            )
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # Number of standard deviations above the row mean at which to cut.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel exists; reuse the PyTorch implementation.
        return self.forward_native(x)


198
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)

    Args:
        approximate: GELU approximation mode, ``"none"`` or ``"tanh"``.

    Raises:
        ValueError: if ``approximate`` is not a supported mode.
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # approximate is validated above, so it is either "none" or "tanh".
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            else:
                self.op = torch.ops._C.gelu_tanh_and_mul
            # The warning belongs to this branch; keeping it nested ensures the
            # XPU `elif` below is an alternative to the platform check rather
            # than to this ROCm condition.
            if current_platform.is_rocm() and approximate == "tanh":
                logger.warning_once(
                    "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                    "with torch.compile. For native implementation, fallback to 'none' "
                    "approximation. The custom kernel implementation is unaffected."
                )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        approximate = self.approximate
        if current_platform.is_rocm() and approximate == "tanh":
            approximate = "none"
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    def _launch_fused(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it in."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._launch_fused(x)

    def extra_repr(self) -> str:
        return f"approximate={repr(self.approximate)}"
258

259

260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant (GPT-OSS style).

    Unlike the other *AndMul ops in this file, gate and up values are
    interleaved along the last dimension: gate at even indices, up at odd
    indices. The gate is clamped from above by ``limit``, the up value to
    ``[-limit, limit]``, and the output is
    ``(up + 1) * gate * sigmoid(alpha * gate)``.

    Reference:
    https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    """

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        bound = self.limit
        gate = x[..., 0::2].clamp(max=bound)
        up = x[..., 1::2].clamp(min=-bound, max=bound)
        return (up + 1) * (gate * torch.sigmoid(gate * self.alpha))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


289
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new").

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The CUDA/CPU op writes into a caller-provided buffer.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result directly.
        return self.op(x)
312

313

314
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU ("gelu_fast").

    Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The CUDA/CPU op writes into a caller-provided buffer.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result directly.
        return self.op(x)
336

337

338
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The op writes its result into a caller-provided buffer.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # Unlike NewGELU/FastGELU, the IPEX quick-GELU op also takes an
        # output buffer, so this mirrors forward_cuda.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

367

368
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, ``relu(x) ** 2``.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO : implement cuda kernels; until then reuse the native path.
        return self.forward_native(x)


383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010

    Elementwise, with ``a_p = softplus(self.alpha_p)`` and
    ``a_n = beta + softplus(self.alpha_n)``:

        x > 0:  a_p * x * x + beta * x
        x <= 0: a_n * (expm1(min(x, eps)) - x) + beta * x

    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    and use it for CUDA inputs outside torch.compile tracing.
    Otherwise, we emit a single warning and use xIELU Python.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        # Parameters are stored in inverse-softplus space, log(exp(v) - 1), so
        # that softplus(alpha_p) == alpha_p_init and
        # beta + softplus(alpha_n) == alpha_n_init at initialisation.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        # (the kernel currently takes beta/eps as Python scalars).
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Optional dependency: any failure (missing package, missing
            # torch class) falls back to the Python implementation.
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU (see the class docstring for the formula)."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            # reshape (not view): view raises on non-contiguous inputs, while
            # reshape returns a view whenever one is possible.
            x = x.reshape(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.reshape(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        # Use the CUDA kernel only for CUDA inputs and never while dynamo is
        # tracing; everything else goes through the Python implementation.
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)

494

495
496
497
498
499
500
501
502
503
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
504
505
        intermediate_size: int,
        input_is_parallel: bool = True,
506
        params_dtype: torch.dtype | None = None,
507
508
509
    ):
        super().__init__()
        self.act = act_module
510
        self.input_is_parallel = input_is_parallel
511
512
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
513
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
514
515
516
517
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
518
        self.scales = nn.Parameter(
519
520
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
521
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
522

523
    def forward(self, x: torch.Tensor) -> torch.Tensor:
524
525
        return self.act(x) / self.scales

526
527
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
528
529
530
531
532
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
533
534
535
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

536

537
538
539
540
541
def _gelu_pytorch_tanh() -> nn.Module:
    """Build the module for the ``gelu_pytorch_tanh`` activation.

    Returns ``nn.GELU(approximate="tanh")``, except on ROCm where a one-time
    warning is logged and exact ``nn.GELU(approximate="none")`` is returned.
    """
    if current_platform.is_rocm():
        # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
        logger.warning_once(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
            "Falling back to GELU(approximate='none')."
        )
        return nn.GELU(approximate="none")
    return nn.GELU(approximate="tanh")


# Name -> zero-argument factory for element-wise activation modules.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": _gelu_pytorch_tanh,
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
561
562


563
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an element-wise activation module by name (case-insensitive).

    Fully qualified ``torch.nn.modules.<...>.<Name>`` strings are reduced to
    their last dotted component; ``Identity`` in that form returns a new
    ``nn.Identity``.

    Raises:
        ValueError: if the (reduced) name is not a registered activation.
    """
    key = act_fn_name.lower()
    if key.startswith("torch.nn.modules."):
        key = key.rsplit(".", 1)[-1]
        if key == "identity":
            return nn.Identity()

    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
577
578


579
580
581
582
583
584
585
586
# Name -> factory for fused activation-and-mul modules. The classes are
# callable directly, so no wrapper lambdas are needed; "gelu" and "geglu"
# both build GeluAndMul with its default (exact) approximation.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": GeluAndMul,
        "silu": SiluAndMul,
        "geglu": GeluAndMul,
        "swigluoai": SwigluOAIAndMul,
    }
)
587
588


589
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if the name is not a registered activation-and-mul op.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")