activation.py 22.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
20
from vllm.utils.collection_utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
21
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
# Module-level logger; the activations below use it for one-time warnings
# (logger.warning_once).
logger = init_logger(__name__)

# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        # Gate values not greater than this are replaced by 0 (F.threshold).
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel, invoked as op(out, x, threshold) in forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound; always dispatch to the PyTorch version.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        # The kernel fills `result` in place.
        self.op(result, x, self.threshold)
        return result


# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel, chosen per call by envs.VLLM_USE_OPT_OP.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # Both kernels take (out, x) and write the result into `out`.
        op = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native
        # XPU deliberately binds no kernel: ipex_ops.silu_and_mul computes
        # silu(x[:d]) * x[d:] (SiluAndMul's ordering), the mirror image of
        # this op, so binding it would silently return wrong values.
        # forward_xpu uses the PyTorch implementation instead.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: no mul_and_silu kernel is bound, so use PyTorch ops."""
        return self.forward_native(x)

# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """
        Args:
            activation_sparsity: target fraction of gate activations to zero;
                must lie strictly between 0 and 1.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: on an unknown ``approximate`` mode, or when
                ``activation_sparsity`` is 0.0 or outside (0, 1).
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # Normal.icdf returns -inf/+inf at 0/1 and nan outside [0, 1], which
        # would turn the cutoff below into garbage (all-zero or all-inf/nan
        # outputs). The negated comparison also rejects nan.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}."
            )
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        # z-score below which a fraction `activation_sparsity` of a standard
        # normal distribution lies; used to place the cutoff in _gaussian_topk.
        normal_dist = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        # Values below the cutoff become 0; the rest are shifted down by it.
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        """
        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: on an unknown ``approximate`` mode.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Bind the fused kernels for the current platform as one dispatch
        # chain. Platforms matching none of these branches get no `self.op`.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
            else:  # "tanh" (validated above)
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

        # Informational only: forward_native applies the fallback itself.
        if current_platform.is_rocm() and approximate == "tanh":
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        approximate = self.approximate
        if current_platform.is_rocm() and approximate == "tanh":
            approximate = "none"
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # Both kernels take (out, x) and write the result into `out`.
        op = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        return f"approximate={repr(self.approximate)}"

# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant used by GPT-OSS.

    The last dimension of x interleaves the two halves: gate = x[..., 0::2],
    up = x[..., 1::2]. With gate clamped from above at ``limit`` and up
    clamped to [-limit, limit], the output is
    (up + 1) * gate * sigmoid(alpha * gate).
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = x[..., ::2].clamp(max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        swish_gate = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * swish_gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new").

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    # --8<-- [end:gelu_new]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # Called as op(out, x) in forward_cuda.
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # Called as op(x) and returns the result (see forward_xpu).
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        inner = sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)

# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x**2))),
    where 0.7978845608 is approximately sqrt(2 / pi).
    """

    # --8<-- [end:gelu_fast]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # Called as op(out, x) in forward_cuda.
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # Called as op(x) and returns the result (see forward_xpu).
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        inner = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)

# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Quick GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op is called with the same (out, x) convention as the
        # CUDA kernel, so the XPU path mirrors forward_cuda.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU, relu(x) ** 2, introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        positive_part = F.relu(x)
        return torch.square(positive_part)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel yet, so reuse the PyTorch implementation.
        # TODO : implement cuda kernels
        return self.forward_native(x)


# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If nickjbrowning/XIELU is installed, the experimental xIELU CUDA kernel is
    used; otherwise a single warning is emitted and the Python version is used.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        # alpha_p / alpha_n are stored pre-softplus: log(exp(a) - 1) inverts
        # softplus, so in _xielu_python softplus(self.alpha_p) recovers
        # alpha_p_init and beta + softplus(self.alpha_n) recovers alpha_n_init
        # (up to rounding in `dtype`).
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented: host-side float copies
        # so the CUDA path does not need to call .item() on the buffers.
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Optional CUDA backend. Stays None if loading it fails;
        # `_xielu_cuda_fn` is only assigned once the CUDA object exists.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Reference xIELU built from plain PyTorch ops."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            # reshape (not view): a non-contiguous input is copied instead of
            # raising; contiguous inputs still get a zero-copy view.
            x = x.reshape(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        # Use the CUDA backend only for CUDA tensors outside of dynamo tracing.
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)

541
542
543
544
545
546
547
548
549
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
550
551
        intermediate_size: int,
        input_is_parallel: bool = True,
552
        params_dtype: torch.dtype | None = None,
553
554
555
    ):
        super().__init__()
        self.act = act_module
556
        self.input_is_parallel = input_is_parallel
557
558
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
559
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
560
561
562
563
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
564
        self.scales = nn.Parameter(
565
566
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
567
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
568

569
    def forward(self, x: torch.Tensor) -> torch.Tensor:
570
571
        return self.act(x) / self.scales

572
573
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
574
575
576
577
578
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
579
580
581
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

582

def _gelu_pytorch_tanh() -> nn.Module:
    """Build the module for the ``gelu_pytorch_tanh`` activation name.

    On ROCm, exact GELU is returned instead (with a one-time warning) because
    PyTorch's native tanh-approximated GELU is unstable there.
    """
    if current_platform.is_rocm():
        # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
        logger.warning_once(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
            "Falling back to GELU(approximate='none')."
        )
        return nn.GELU(approximate="none")
    return nn.GELU(approximate="tanh")


# Activation name -> zero-argument factory producing the activation module.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": _gelu_pytorch_tanh,
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Return the activation module registered under ``act_fn_name``.

    Matching is case-insensitive. A fully qualified ``torch.nn.modules.*``
    path is reduced to its final dotted component, and ``Identity`` in that
    form yields a fresh ``nn.Identity``.

    Raises:
        ValueError: if the normalized name is not in the registry.
    """
    name = act_fn_name.lower()
    if name.startswith("torch.nn.modules."):
        name = name.rpartition(".")[2]
        if name == "identity":
            return nn.Identity()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    return _ACTIVATION_REGISTRY[name]


# Activation name -> factory for the fused "activation(x[:d]) * x[d:]" style
# modules. Classes are used directly as factories; SwigluOAIAndMul also
# accepts optional (alpha, limit) constructor arguments.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        # "gelu" and "geglu" are aliases for the same GeGLU module.
        "gelu": GeluAndMul,
        "silu": SiluAndMul,
        "geglu": GeluAndMul,
        "swigluoai": SwigluOAIAndMul,
    }
)


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Return the activation-and-mul module (e.g. SiluAndMul) for a name.

    Matching is case-insensitive.

    Raises:
        ValueError: if the lowercased name is not in the registry.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")