activation.py 21.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
20
from vllm.utils.collection_utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
21
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
25

26
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split in half; the first half is
    passed through FATReLU (values not above ``threshold`` become 0) and
    multiplied element-wise by the second half, i.e.
    x -> FATReLU(x[:d]) * x[d:] with d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel used by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # On CPU always dispatch to the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation of the gated FATReLU."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
60
61


62
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel, selected in forward_cuda when
            # envs.VLLM_USE_OPT_OP is set.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # On CPU always dispatch to the PyTorch implementation.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        This is a staticmethod, so it has no access to ``self``. It therefore
        must not redirect to ``forward_cuda``: the previous attempt to do so
        raised ``NameError`` whenever ``VLLM_USE_OPT_OP`` was enabled outside
        of compilation. The optimized kernel is chosen in ``forward_cuda``.
        """
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel (optimized variant if VLLM_USE_OPT_OP is set)."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

111

112
113
114
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant with the SiLU applied to the second half.

    Computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): this binds the silu_and_mul op, not a
            # mul_and_silu one; no forward_xpu is defined in this class to
            # use it - confirm before implementing the TODO below.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # On CPU always dispatch to the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        left, right = x[..., :half], x[..., half:]
        return left * F.silu(right)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

149

Robert Shaw's avatar
Robert Shaw committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()

        # --- GELU configuration -------------------------------------------
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate
        if approximate == "tanh" and current_platform.is_rocm():
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # --- Sparsity configuration ---------------------------------------
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        # Number of standard deviations above the mean at which the
        # standard normal CDF equals the requested sparsity.
        standard_normal = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = standard_normal.icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = torch.mean(x, dim=-1, keepdim=True)
        row_std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        threshold = row_mean + row_std * self.std_multiplier
        # Everything at or below the per-row threshold is zeroed.
        return F.relu(x - threshold)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        sparse_gate = self._gaussian_topk(gate)
        return F.gelu(sparse_gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No dedicated kernel: reuse the PyTorch implementation."""
        return self.forward_native(x)


206
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """Gated GELU (GeGLU) activation.

    Computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # After validation the mode is either "tanh" or "none".
        use_tanh = approximate == "tanh"

        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if use_tanh:
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
            else:
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt

        if current_platform.is_rocm() and use_tanh:
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = (
                ipex_ops.gelu_tanh_and_mul if use_tanh else ipex_ops.gelu_and_mul
            )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        mode = self.approximate
        if mode == "tanh" and current_platform.is_rocm():
            mode = "none"
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=mode) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel (optimized variant if VLLM_USE_OPT_OP is set)."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Show the approximation mode in the module repr."""
        return f"approximate={self.approximate!r}"
271

272

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant over interleaved gate/up channels.

    Reference implementation:
    https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    """

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # Even channels of the last dim are the gate, odd channels the up
        # projection.
        gate = x[..., ::2]
        up = x[..., 1::2]
        # The gate is only clamped from above; up is clamped on both sides.
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        swish = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * swish

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (half,), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Show alpha and limit in the module repr."""
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


302
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation ("gelu_new")."""

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a tensor shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Call the IPEX op and return its result."""
        return self.op(x)
325

326

327
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a hard-coded constant ("gelu_fast")."""

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a tensor shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Call the IPEX op and return its result."""
        return self.op(x)
349

350

351
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a tensor shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op, writing into a tensor shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

380

381
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU (relu^2) activation, introduced in
    https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch implementation."""
        # TODO : implement cuda kernels
        return self.forward_native(x)


396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    xIELU activation function (https://arxiv.org/abs/2411.13010).

    When the nickjbrowning/XIELU package can be imported, its CUDA
    implementation is used; otherwise a single warning is logged and the
    pure-PyTorch version is used instead.
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()

        # The learnable alphas are stored as log(exp(v) - 1), the inverse of
        # the softplus applied to them in _xielu_python.
        alpha_p_start = torch.tensor(alpha_p_init, dtype=dtype)
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(alpha_p_start) - 1).unsqueeze(0)
        )
        alpha_n_start = torch.tensor(alpha_n_init - beta, dtype=dtype)
        self.alpha_n = nn.Parameter(
            torch.log(torch.exp(alpha_n_start) - 1).unsqueeze(0)
        )

        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None unless the optional CUDA extension loads successfully.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            status = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                status += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                status += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(status)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU."""
        pos_alpha = F.softplus(self.alpha_p)
        neg_alpha = self.beta + F.softplus(self.alpha_n)
        positive_branch = pos_alpha * x * x + self.beta * x
        negative_branch = (
            torch.expm1(torch.min(x, self.eps)) - x
        ) * neg_alpha + self.beta * x
        return torch.where(x > 0, positive_branch, negative_branch)

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        input_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if input_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                input_shape,
                x.shape,
            )
        out = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Restore the caller's original shape.
        return out.view(input_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA extension when possible, else the PyTorch version."""
        cuda_usable = self._xielu_cuda_obj is not None and input.is_cuda
        if cuda_usable:
            if torch._dynamo.is_compiling():
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
            else:
                return self._xielu_cuda_fn(input)
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Same dispatch as forward_native."""
        return self.forward_native(input)

507

508
509
510
511
512
513
514
515
516
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
517
518
        intermediate_size: int,
        input_is_parallel: bool = True,
519
        params_dtype: torch.dtype | None = None,
520
521
522
    ):
        super().__init__()
        self.act = act_module
523
        self.input_is_parallel = input_is_parallel
524
525
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
526
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
527
528
529
530
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
531
        self.scales = nn.Parameter(
532
533
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
534
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
535

536
    def forward(self, x: torch.Tensor) -> torch.Tensor:
537
538
        return self.act(x) / self.scales

539
540
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
541
542
543
544
545
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
546
547
548
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

549

550
551
552
553
554
def _make_gelu_pytorch_tanh() -> nn.Module:
    """Build the "gelu_pytorch_tanh" activation.

    On ROCm a warning is logged and the exact GELU is returned instead of
    the tanh approximation.
    """
    # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
    if current_platform.is_rocm():
        logger.warning_once(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
            "Falling back to GELU(approximate='none')."
        )
        return nn.GELU(approximate="none")
    return nn.GELU(approximate="tanh")


# Zero-argument factories for plain activations, keyed by lowercase name and
# looked up through get_act_fn.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: _make_gelu_pytorch_tanh(),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
574
575


576
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Names of the form ``torch.nn.modules.<...>.<Name>`` are reduced to their
    last dotted component; ``identity`` maps to ``nn.Identity``.

    Raises:
        ValueError: if the resolved name is not in the registry.
    """
    name = act_fn_name.lower()

    if name.startswith("torch.nn.modules."):
        # Keep only the trailing class name of the fully-qualified path.
        name = name.rpartition(".")[2]
        if name == "identity":
            return nn.Identity()

    if name in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[name]

    raise ValueError(f"Activation function {name!r} is not supported.")
590
591


592
593
594
595
596
597
598
599
# Factories for gated ("activation-and-mul") activations, keyed by the
# lowercase name that get_act_and_mul_fn looks up. "gelu" and "geglu" both
# build a GeluAndMul.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
600
601


602
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if the name is not in the registry.
    """
    name = act_fn_name.lower()
    if name in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[name]

    raise ValueError(f"Activation function {name!r} is not supported.")