activation.py 24.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
csy0225's avatar
csy0225 committed
20
from vllm.triton_utils import tl, triton
21
from vllm.utils.collection_utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
# Module-level logger shared by the activation classes and registries below.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
25

csy0225's avatar
csy0225 committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Compute min(silu(gate), limit) * clamp(up, -limit, limit) for one tile.

    Each program handles one row (program axis 0) and one block of
    ``BLOCK_SIZE`` columns (program axis 1). Within a row, ``gate`` is read
    from columns ``[0, d)`` and ``up`` from columns ``[d, 2 * d)``.

    Args:
        o_ptr: Pointer to the output; row ``i`` starts at ``o_ptr + o_stride * i``.
        o_stride: Distance, in elements, between consecutive output rows.
        x_ptr: Pointer to the input; row ``i`` starts at ``x_ptr + x_stride * i``.
        x_stride: Distance, in elements, between consecutive input rows.
        limit: Clamp bound (upper bound for the gate, symmetric for ``up``).
        d: Number of output columns per row.
        BLOCK_SIZE: Number of columns processed per program.
    """
    # Row index is widened to int64 before the stride multiply — presumably to
    # avoid 32-bit overflow in the row offset on large inputs.
    i = tl.program_id(axis=0).to(tl.int64)
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    # Column offsets within the row; elements in a row are addressed with
    # unit stride. The mask disables lanes past the last valid column.
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d

    # Arithmetic is done in float32 regardless of the storage dtype.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    # silu(gate) is clamped from above only; up is clamped on both sides.
    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)

    result = gate_clamped * up_clamped
    # NOTE(review): the cast uses the *input* element type, not the output's;
    # this is only equivalent when both tensors share a dtype — confirm.
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Launch the clamped SwiGLU Triton kernel on a 2-D input.

    Writes ``min(silu(gate), limit) * clamp(up, -limit, limit)`` into
    ``output``, where ``gate = input[:, :d]``, ``up = input[:, d:]`` and
    ``d = input.shape[1] // 2``.

    Args:
        output: Destination tensor, written in place. The kernel writes ``d``
            columns for each of the ``b`` input rows.
        input: Source tensor of shape ``(b, 2 * d)``.
        limit: Clamp bound passed to the kernel.

    Raises:
        AssertionError: If ``input`` is not 2-D or its last dimension is odd.
    """
    # Check the rank *before* unpacking the shape: with the unpack first, a
    # non-2-D input raised an opaque unpacking ValueError and the rank
    # assertion could never fire.
    assert input.ndim == 2, (
        f"swiglustep_and_mul_triton expects a 2-D input, got {input.ndim}-D"
    )
    b, n = input.shape
    assert n % 2 == 0, f"last dimension must be even, got {n}"
    d = n // 2

    def grid(meta):
        # One program per row and per BLOCK_SIZE-wide slice of the d columns.
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


77
# --8<-- [start:fatrelu_and_mul]
78
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        """Store the threshold and pick the platform-specific implementation.

        Args:
            threshold: Values of the first half that are not strictly greater
                than this are replaced by zero.
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound on CPU; route calls to the native path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold keeps entries > threshold and replaces the rest with 0.0.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out
114
115


116
# --8<-- [start:silu_and_mul]
117
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        """Pick the platform-specific kernel.

        Args:
            compile_native: Forwarded unchanged to ``CustomOp.__init__``.
        """
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            # Imported lazily so the IPEX dependency is only touched on XPU.
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound on CPU; route calls to the native path.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

161

162
# --8<-- [start:mul_and_silu]
163
164
165
166
167
168
169
170
171
172
173
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        """Pick the platform-specific kernel."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): this binds `silu_and_mul`, which applies SiLU to
            # the *first* half, whereas this class applies it to the second.
            # No method in this class uses `self.op` on XPU yet (forward_xpu
            # is still a TODO below) — confirm the right op before wiring it.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # No custom kernel is bound on CPU; route calls to the native path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


203
# --8<-- [start:gelu_and_mul_sparse]
Robert Shaw's avatar
Robert Shaw committed
204
205
206
207
208
209
210
211
212
213
214
215
216
217
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """A sparsified GeGLU activation (used in Gemma3n).

    The function computes x -> GELU(gaussian_topk(x[:d])) * x[d:] where
    d = x.shape[-1] // 2. In terms of the surrounding MLP this corresponds to:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Configure the GELU mode and the sparsity cutoff.

        Args:
            activation_sparsity: Target sparsity; converted to a standard
                normal quantile used as the cutoff multiplier.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is unknown or
                ``activation_sparsity`` is 0.0.
        """
        super().__init__()
        # Gelu. Validate before storing so a bad mode never lands on self.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        # Inverse CDF of N(0, 1) at the target sparsity: the number of standard
        # deviations above the mean at which the cutoff is placed.
        normal_dist = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        # Entries at or below the cutoff become zero; the rest are shifted
        # down by the cutoff.
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native implementation (no custom kernel is used)."""
        return self.forward_native(x)


262
# --8<-- [start:gelu_and_mul]
263
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and pick the platform-specific kernel.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If ``approximate`` is not a known mode.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
        # Note: the XPU branch below is the `elif` of the ROCm warning check,
        # not of the CUDA/CPU check above.
        if current_platform.is_rocm() and approximate == "tanh":
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        approximate = self.approximate
        if current_platform.is_rocm() and approximate == "tanh":
            approximate = "none"
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the configured approximation mode for the module repr."""
        return f"approximate={repr(self.approximate)}"
325

326

327
# --8<-- [start:swigluoai_and_mul]
328
329
330
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped SwiGLU variant with interleaved gate/up columns.

    With gate = x[..., ::2] and up = x[..., 1::2], the native path computes
        gate = clamp(gate, max=limit)
        up = clamp(up, -limit, limit)
        out = (up + 1) * gate * sigmoid(alpha * gate)
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        """Store the activation hyper-parameters.

        Args:
            alpha: Scale applied to the gate inside the sigmoid.
            limit: Clamp bound (upper bound for the gate, symmetric for up).
        """
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""

        # Gate and up values are interleaved along the last dimension.
        gate, up = x[..., ::2], x[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        """Return the configured hyper-parameters for the module repr."""
        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"


csy0225's avatar
csy0225 committed
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        """Store the clamp bound.

        Args:
            limit: Clamp bound (upper bound for silu(gate), symmetric for up).

        Raises:
            ValueError: If ``limit`` is None.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Triton kernel, writing into a freshly allocated output."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if x.ndim == 2:
            swiglustep_and_mul_triton(out, x, self.limit)
        else:
            # The Triton launcher only accepts 2-D tensors, so collapse the
            # leading dimensions. `out` was just allocated contiguously, so
            # the view aliases it and the kernel's writes land in `out`.
            swiglustep_and_mul_triton(
                out.view(-1, d), x.reshape(-1, x.shape[-1]), self.limit
            )
        return out

    def extra_repr(self) -> str:
        """Return the configured clamp bound for the module repr."""
        return f"limit={repr(self.limit)}"


397
# --8<-- [start:gelu_new]
398
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """

    # --8<-- [end:gelu_new]

    def __init__(self):
        """Pick the platform-specific kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op and return its result directly."""
        return self.op(x)
423

424

425
# --8<-- [start:gelu_fast]
426
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a precomputed constant.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2))).
    """

    # --8<-- [end:gelu_fast]

    def __init__(self):
        """Pick the platform-specific kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op and return its result directly."""
        return self.op(x)
450

451

452
# --8<-- [start:quick_gelu]
453
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        """Pick the platform-specific kernel."""
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, writing into a freshly allocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

484

485
# --8<-- [start:relu2]
486
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native implementation (no custom kernel yet)."""
        # TODO : implement cuda kernels
        return self.forward_native(x)


503
# --8<-- [start:xielu]
504
505
506
507
508
509
510
511
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python

    The Python path computes, with alpha_p = softplus(self.alpha_p) and
    alpha_n = beta + softplus(self.alpha_n):
        x > 0:  alpha_p * x^2 + beta * x
        else:   (expm1(min(x, eps)) - x) * alpha_n + beta * x
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        """Create the learnable parameters and probe for the CUDA extension.

        Args:
            alpha_p_init: Initial effective value of alpha_p (after softplus).
            alpha_n_init: Initial effective value of alpha_n (after adding
                beta to the softplus output).
            beta: Linear coefficient, stored as a buffer.
            eps: Upper bound applied to x before expm1 in the non-positive
                branch, stored as a buffer.
            dtype: Dtype of the parameters and buffers.
            with_vector_loads: Flag forwarded to the CUDA implementation.
        """
        super().__init__()
        # Parameters are stored in inverse-softplus space, log(exp(v) - 1), so
        # that softplus(param) recovers the requested initial value.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        # Best-effort probe: any failure to load the optional extension falls
        # back to the Python implementation with a single warning.
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate xIELU with plain PyTorch ops."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Undo the temporary 3-D reshape so callers see the input's shape.
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA extension when available, else the Python version.

        The CUDA path is taken only when the extension loaded, the input is a
        CUDA tensor, and torch._dynamo is not currently compiling.
        """
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Delegate to forward_native, which selects the implementation."""
        return self.forward_native(input)

617

618
619
620
621
622
623
624
625
626
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
627
628
        intermediate_size: int,
        input_is_parallel: bool = True,
629
        params_dtype: torch.dtype | None = None,
630
631
632
    ):
        super().__init__()
        self.act = act_module
633
        self.input_is_parallel = input_is_parallel
634
635
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
636
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
637
638
639
640
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
641
        self.scales = nn.Parameter(
642
643
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
644
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
645

646
    def forward(self, x: torch.Tensor) -> torch.Tensor:
647
648
        return self.act(x) / self.scales

649
650
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
651
652
653
654
655
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
656
657
658
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

659

660
661
662
663
664
def _build_gelu_pytorch_tanh() -> nn.Module:
    """Build the module for "gelu_pytorch_tanh", with a ROCm fallback."""
    if current_platform.is_rocm():
        # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
        logger.warning_once(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
            "Falling back to GELU(approximate='none')."
        )
        return nn.GELU(approximate="none")
    return nn.GELU(approximate="tanh")


# Maps an activation name to a zero-argument factory for its module.
# get_act_fn lower-cases the requested name before looking it up here.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: _build_gelu_pytorch_tanh(),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
684
685


686
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The name is matched case-insensitively. A dotted name starting with
    "torch.nn.modules." is reduced to its last component first; "identity"
    in that form returns ``nn.Identity()`` without consulting the registry.

    Raises:
        ValueError: If the (reduced) name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()

    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            return nn.Identity()
        act_fn_name = activation_name

    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
700
701


702
703
704
705
706
707
708
709
# Maps an activation name to a factory for the fused "activation-and-mul"
# module. get_act_and_mul_fn lower-cases the requested name before lookup.
# "gelu" and "geglu" both build a GeluAndMul.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
710
711


712
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    The name is matched case-insensitively.

    Raises:
        ValueError: If the name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]