# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
20
from vllm.utils.collection_utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
21

22
23
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
# --8<-- [start:fatrelu_and_mul]
26
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        """Store the FATReLU threshold and pick the platform kernel.

        Args:
            threshold: Gate values less than or equal to this are zeroed
                (see ``F.threshold`` in ``forward_native``).
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound for CPU here; route calls to the
            # PyTorch-native implementation instead.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation: threshold the gate half, then multiply."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
62
63


64
# --8<-- [start:silu_and_mul]
65
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        """Select the fused kernel for the current platform.

        Args:
            compile_native: Forwarded unchanged to ``CustomOp.__init__``.
        """
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX bindings are only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound for CPU here; use the native path.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel; same calling convention as ``forward_cuda``."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

109

110
# --8<-- [start:mul_and_silu]
111
112
113
114
115
116
117
118
119
120
121
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        """Select the fused kernel for the current platform."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): this binds silu_and_mul (activation on the first
            # half), not a mul_and_silu kernel. No forward_xpu is defined in
            # this class, so `self.op` is not called from here on XPU --
            # confirm before wiring up forward_xpu.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound for CPU here; use the native path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


151
# --8<-- [start:gelu_and_mul_sparse]
Robert Shaw's avatar
Robert Shaw committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the GELU mode and precompute the sparsity cutoff multiplier.

        Args:
            activation_sparsity: Target sparsity; fed to the standard normal
                inverse CDF to obtain the cutoff in units of std.
            approximate: GELU approximation mode, ``"none"`` or ``"tanh"``.

        Raises:
            ValueError: If ``approximate`` is not a known mode, or if
                ``activation_sparsity`` is exactly 0.0.
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # Quantile of N(0, 1) at the target sparsity; scaled by the per-row
        # std in `_gaussian_topk` to place the cutoff.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        # Everything at or below the cutoff becomes exactly zero.
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No fused kernel is used here; delegate to the native path."""
        return self.forward_native(x)


210
# --8<-- [start:gelu_and_mul]
211
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching platform kernel.

        Args:
            approximate: GELU approximation mode, ``"none"`` or ``"tanh"``.

        Raises:
            ValueError: If ``approximate`` is not a known mode.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
        # NOTE: the XPU branch below is an `elif` of the ROCm warning check,
        # not of the kernel selection above.
        if current_platform.is_rocm() and approximate == "tanh":
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        approximate = self.approximate
        if current_platform.is_rocm() and approximate == "tanh":
            approximate = "none"
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel; same calling convention as ``forward_cuda``."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Show the configured approximation mode in the module repr."""
        return f"approximate={repr(self.approximate)}"
273

274

275
# --8<-- [start:swigluoai_and_mul]
276
277
278
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant with interleaved gate/up channels.

    The native path reads gate values from the even positions and up values
    from the odd positions of the last dimension, and computes
    (up + 1) * gate * sigmoid(alpha * gate) after clamping both.
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        """Store the sigmoid scale and clamp limit.

        Args:
            alpha: Multiplier applied to the gate inside the sigmoid.
            limit: Upper clamp for the gate; symmetric clamp for the up values.
        """
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""

        # Gate and up channels are interleaved along the last dimension.
        gate, up = x[..., ::2], x[..., 1::2]
        # The gate is only clamped from above; up is clamped on both sides.
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Show alpha and limit in the module repr."""
        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"


307
# --8<-- [start:gelu_new]
308
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """

    # --8<-- [end:gelu_new]

    def __init__(self):
        """Bind the platform kernel for the current device type."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op, which is called with ``x`` only and returns the result."""
        return self.op(x)
333

334

335
# --8<-- [start:gelu_fast]
336
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a precomputed constant.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2))).
    """

    # --8<-- [end:gelu_fast]

    def __init__(self):
        """Bind the platform kernel for the current device type."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op, which is called with ``x`` only and returns the result."""
        return self.op(x)
360

361

362
# --8<-- [start:quick_gelu]
363
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        """Bind the platform kernel for the current device type."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a tensor shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel; same calling convention as ``forward_cuda``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

394

395
# --8<-- [start:relu2]
396
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native path until a dedicated kernel exists."""
        # TODO : implement cuda kernels
        return self.forward_native(x)


413
# --8<-- [start:xielu]
414
415
416
417
418
419
420
421
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        """Create the learnable parameters and probe for the CUDA extension.

        Args:
            alpha_p_init: Initial positive-branch coefficient (after softplus).
            alpha_n_init: Initial negative-branch coefficient (after softplus
                and adding ``beta``).
            beta: Linear-term coefficient, stored as a buffer.
            eps: Upper bound applied to ``x`` inside ``expm1`` on the
                negative branch, stored as a buffer.
            dtype: dtype used for the parameters and buffers.
            with_vector_loads: Passed through to the CUDA kernel.
        """
        super().__init__()
        # Parameters are stored in inverse-softplus space, log(exp(v) - 1),
        # so that softplus() in `_xielu_python` recovers the initial values.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None when the optional extension cannot be loaded;
        # `forward_native` checks this before using `_xielu_cuda_fn`.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Best-effort: any failure to load the extension falls back to
            # the Python implementation.
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU: quadratic for x > 0, expm1-based otherwise."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Restore the caller's shape after the 3D round-trip.
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA extension when loaded and usable, else the Python path."""
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Same dispatch as ``forward_native``."""
        return self.forward_native(input)

527

528
529
530
531
532
533
534
535
536
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
537
538
        intermediate_size: int,
        input_is_parallel: bool = True,
539
        params_dtype: torch.dtype | None = None,
540
541
542
    ):
        super().__init__()
        self.act = act_module
543
        self.input_is_parallel = input_is_parallel
544
545
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
546
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
547
548
549
550
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
551
        self.scales = nn.Parameter(
552
553
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
554
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
555

556
    def forward(self, x: torch.Tensor) -> torch.Tensor:
557
558
        return self.act(x) / self.scales

559
560
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
561
562
563
564
565
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
566
567
568
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

569

570
571
572
573
574
# Maps lower-case activation names to zero-argument factories, so a module
# is only constructed when its factory is invoked.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        # On ROCm: log once and build the exact GELU instead; the tuple's
        # second element is the value returned by the lambda.
        "gelu_pytorch_tanh": lambda: (
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
                "Falling back to GELU(approximate='none')."
            ),
            nn.GELU(approximate="none"),
        )[1]
        if current_platform.is_rocm()
        else nn.GELU(approximate="tanh"),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
594
595


596
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The lookup is case-insensitive. Names starting with
    ``torch.nn.modules.`` are reduced to their last dotted component;
    ``identity`` is handled directly and returns ``nn.Identity()``.

    Raises:
        ValueError: If the (normalised) name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()

    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            return nn.Identity()
        act_fn_name = activation_name

    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
610
611


612
613
614
615
616
617
618
619
# Maps lower-case names to factories for the fused "activation and multiply"
# modules defined in this file.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
620
621


622
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (e.g. SiluAndMul) function by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: If the name is not in the activation-and-mul registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]