"tests/vscode:/vscode.git/clone" did not exist on "109e15a335a20251cbefa0a81bf51cd7624eae27"
activation.py 26.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import CpuArchEnum, current_platform
csy0225's avatar
csy0225 committed
20
from vllm.triton_utils import tl, triton
21
from vllm.utils.collection_utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
# Module-level logger (via vllm.logger.init_logger) shared by the activation
# layers in this file; used for one-time platform/fallback warnings.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
25

csy0225's avatar
csy0225 committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Fused SwiGLU-step kernel:
    #   out[i, c] = min(silu(x[i, c]), limit) * clamp(x[i, d + c], -limit, limit)
    # for c in [0, d). Each input row holds `gate` in columns [0, d) and `up`
    # in columns [d, 2d).
    # Grid: axis 0 picks the row `i`; axis 1 picks a BLOCK_SIZE-wide chunk `j`
    # of the d output columns.
    # NOTE: columns are addressed as `row_ptr + offsets`, so both x and o are
    # assumed to be unit-stride along their last dimension.
    # NOTE: `limit` and `d` are constexpr, so each distinct value compiles a
    # separate specialization of this kernel.
    # Row index is widened to int64, presumably so `stride * i` cannot
    # overflow 32 bits on large tensors.
    i = tl.program_id(axis=0).to(tl.int64)
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the row are masked on both load and store.
    mask = offsets < d

    # Compute in fp32 regardless of the input dtype.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    # silu(gate) is clamped only from above; up is clamped symmetrically.
    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)

    result = gate_clamped * up_clamped
    # Cast back to the *input* element type before storing; assumes the
    # output buffer uses the same dtype as the input — TODO confirm callers.
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
) -> None:
    """Run the fused SwiGLU-step Triton kernel, writing the result into ``output``.

    Computes ``min(silu(gate), limit) * clamp(up, -limit, limit)`` where
    ``gate = input[..., :d]``, ``up = input[..., d:]`` and
    ``d = input.shape[-1] // 2``.

    Args:
        output: Destination of shape ``input.shape[:-1] + (d,)``, written in
            place. It must be viewable as ``(rows, d)`` with a unit-stride last
            dimension, because the kernel addresses columns as
            ``row_ptr + offset``.
        input: Tensor whose last dimension has even size ``2 * d``. Any leading
            dimensions are flattened into rows.
        limit: Clamp bound. It is a ``tl.constexpr`` kernel argument, so each
            distinct value compiles a separate kernel specialization.

    Raises:
        ValueError: If ``input`` is 0-d or has an odd last dimension, if
            ``output`` has the wrong shape, or if ``output``'s last dimension
            is not unit-stride.
    """
    # Validate before unpacking anything so bad shapes get a clear error.
    if input.ndim == 0:
        raise ValueError("swiglustep_and_mul_triton expects at least a 1-D input.")
    n = input.shape[-1]
    if n % 2 != 0:
        raise ValueError(f"Last dimension of input must be even, got {n}.")
    d = n // 2
    expected_shape = input.shape[:-1] + (d,)
    if output.shape != expected_shape:
        raise ValueError(
            f"output has shape {tuple(output.shape)}, "
            f"expected {tuple(expected_shape)}."
        )

    rows = math.prod(input.shape[:-1])
    if rows == 0 or d == 0:
        # Nothing to compute; skip launching a kernel over an empty grid.
        return

    # Collapse leading dims into one row axis. For 2-D input this is a view
    # that keeps the caller's row stride. The kernel needs unit-stride columns.
    x_2d = input.reshape(rows, n)
    if x_2d.stride(-1) != 1:
        x_2d = x_2d.contiguous()
    out_2d = output.view(rows, d)
    if d > 1 and out_2d.stride(-1) != 1:
        raise ValueError("output must have a unit-stride last dimension.")

    def grid(meta):
        # One program per (row, BLOCK_SIZE-wide column chunk).
        return (rows, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        out_2d,
        out_2d.stride(0),
        x_2d,
        x_2d.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


77
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    With ``d = x.shape[-1] // 2`` this returns ``FATReLU(x[..., :d]) * x[..., d:]``,
    where FATReLU keeps values above ``threshold`` and zeroes the rest.
    Used by openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel consumed by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No dedicated CPU kernel: always route to the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Reference PyTorch path: threshold the first half, scale by the second."""
        split = x.shape[-1] // 2
        gated = F.threshold(x[..., :split], self.threshold, 0.0)
        return gated * x[..., split:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        split = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], split), dtype=x.dtype, device=x.device)
        self.op(result, x, self.threshold)
        return result
114
115


116
# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation: silu of the first half times the second half.

    With ``d = x.shape[-1] // 2`` this returns ``silu(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        has_fused_kernel = current_platform.is_cuda_alike() or current_platform.is_xpu()
        if has_fused_kernel:
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_cpu():
            # CPU always uses the PyTorch implementation.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate = x[..., :half]
        up = x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU build exposes the same _C kernel as CUDA.
        return self.forward_cuda(x)
152

153

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# --8<-- [start:silu_and_mul_with_clamp]
@CustomOp.register("silu_and_mul_with_clamp")
class SiluAndMulWithClamp(CustomOp):
    """SwiGLU whose gate and up halves are clamped *before* the activation.

    With ``d = x.shape[-1] // 2`` and ``L = swiglu_limit``::

        gate = clamp(x[..., :d], max=L)
        up   = clamp(x[..., d:], min=-L, max=L)
        out  = silu(gate) * up

    Used by some MoE shared experts.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul_with_clamp]

    def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        # Stored as a Python float so the custom op receives a plain scalar.
        self.swiglu_limit = float(swiglu_limit)
        has_fused_kernel = current_platform.is_cuda_alike() or current_platform.is_xpu()
        if has_fused_kernel:
            self.op = torch.ops._C.silu_and_mul_with_clamp
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        bound = self.swiglu_limit
        gate = x[..., :half].clamp(max=bound)
        up = x[..., half:].clamp(min=-bound, max=bound)
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x, self.swiglu_limit)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)


194
# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant that applies silu to the *second* half.

    With ``d = x.shape[-1] // 2`` this returns ``x[..., :d] * silu(x[..., d:])``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        has_fused_kernel = current_platform.is_cuda_alike() or current_platform.is_xpu()
        if has_fused_kernel:
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        multiplier = x[..., :half]
        activated = F.silu(x[..., half:])
        return multiplier * activated

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)
229
230


231
# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparse GeGLU activation (used in Gemma3n).

    With ``d = x.shape[-1] // 2``, ``gate = x[..., :d]`` and ``up = x[..., d:]``::

        gate = _gaussian_topk(gate)   # statistical top-k sparsification
        out  = gelu(gate) * up

    The surrounding projections (up_proj / gate_proj / down_proj) live in the
    model, not in this op.

    Args:
        activation_sparsity: Target fraction of gate values to zero (exact only
            if gate values are roughly Gaussian per row). Must lie in the open
            interval (0, 1); 0.0 is rejected in favour of ``GeluAndMul``.
        approximate: GELU approximation, ``"none"`` or ``"tanh"``. On ROCm,
            ``"tanh"`` is replaced by ``"none"``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # Normal.icdf does not validate its argument: values < 0 or > 1 yield
        # NaN and 1.0 yields +inf (NaN again once multiplied by a zero std),
        # which would silently poison every activation.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0.0, 1.0), "
                f"got {activation_sparsity}."
            )
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # z-score whose standard-normal CDF equals the target sparsity.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Zero values below a Gaussian-estimated percentile of the last dim.

        Fits a Gaussian per row (mean and population std over the last
        dimension) and uses ``cutoff = mean + std * icdf(activation_sparsity)``.
        Returns ``relu(x - cutoff)``: values at or below the cutoff become 0
        and the remaining values are shifted down by ``cutoff``.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel exists; the native path is used on GPU as well.
        return self.forward_native(x)


290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# --8<-- [start:gelu]
@CustomOp.register("gelu")
class GELU(CustomOp):
    # --8<-- [end:gelu]

    """Exact (erf-based) GELU, i.e. ``F.gelu(x, approximate="none")``.

    On ARM CPUs whose ``_C`` extension provides ``activation_lut_bf16``,
    ``forward_cpu`` uses that kernel for contiguous bfloat16 inputs; every
    other case uses the PyTorch implementation.
    """

    def __init__(self):
        super().__init__()
        # Optional ARM bf16 kernel; None when this CPU/build does not have it.
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM and hasattr(
            torch.ops._C, "activation_lut_bf16"
        ):
            self.op = torch.ops._C.activation_lut_bf16
        else:
            self.op = None

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x, approximate="none")

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit None check instead of relying on the op object's truthiness.
        if self.op is not None and x.dtype == torch.bfloat16 and x.is_contiguous():
            out = torch.empty_like(x)
            self.op(out, x, "gelu")
            return out
        return self.forward_native(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


318
# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation: GELU of the first half times the second half.

    With ``d = x.shape[-1] // 2`` this returns ``GELU(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        has_fused_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_fused_kernel:
            # `approximate` is already validated to be "none" or "tanh".
            kernel_name = "gelu_and_mul" if approximate == "none" else "gelu_tanh_and_mul"
            self.op = getattr(torch.ops._C, kernel_name)

        if current_platform.is_rocm() and approximate == "tanh":
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        use_exact = self.approximate == "tanh" and current_platform.is_rocm()
        mode = "none" if use_exact else self.approximate
        half = x.shape[-1] // 2
        return F.gelu(x[..., :half], approximate=mode) * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"
374

375

376
# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    """Clamped SwiGLU variant with *interleaved* gate/up channels.

    Even channels of the last dim are the gate and odd channels are "up"::

        gate = clamp(x[..., 0::2], max=limit)
        up   = clamp(x[..., 1::2], min=-limit, max=limit)
        out  = (up + 1) * gate * sigmoid(alpha * gate)
    """

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        bound = self.limit
        gate = x[..., 0::2].clamp(max=bound)
        up = x[..., 1::2].clamp(min=-bound, max=bound)
        swish = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * swish

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


csy0225's avatar
csy0225 committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2. The clamp on silu is applied *after* the
    activation.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:swiglustep_and_mul]

    def __init__(self, limit: float = 7.0):
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if out.numel() == 0:
            # Nothing to compute; don't launch a kernel over an empty grid.
            return out
        # swiglustep_and_mul_triton addresses rows with a single row stride and
        # assumes unit-stride columns, so fold every leading dimension (e.g.
        # batch, seq) into one row axis. This is a view for contiguous x.
        num_rows = math.prod(x.shape[:-1])
        x_2d = x.reshape(num_rows, x.shape[-1])
        if x_2d.stride(-1) != 1:
            x_2d = x_2d.contiguous()
        # `out` is freshly allocated and contiguous, so this view always works.
        swiglustep_and_mul_triton(out.view(num_rows, d), x_2d, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


446
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    # --8<-- [end:gelu_new]

    """Tanh-approximated GELU ("gelu_new").

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        super().__init__()
        has_fused_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_fused_kernel:
            self.op = torch.ops._C.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)
472

473

474
# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    # --8<-- [end:gelu_fast]

    """Fast tanh-approximated GELU with sqrt(2/pi) pre-folded as 0.7978845608."""

    def __init__(self):
        super().__init__()
        has_fused_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_fused_kernel:
            self.op = torch.ops._C.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * poly))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)
499

500

501
# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def __init__(self):
        super().__init__()
        has_fused_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_fused_kernel:
            self.op = torch.ops._C.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)
527

528

529
# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU, ``relu(x) ** 2``, from https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return F.relu(x).square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO : implement cuda kernels
        return self.forward_native(x)


547
# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """xIELU activation, introduced in https://arxiv.org/abs/2411.13010.

    For ``x > 0``:  ``alpha_p * x**2 + beta * x``
    For ``x <= 0``: ``alpha_n * (expm1(min(x, eps)) - x) + beta * x``

    where ``alpha_p = softplus(self.alpha_p)`` and
    ``alpha_n = beta + softplus(self.alpha_n)`` are learnable.

    If the optional ``xielu`` package (nickjbrowning/XIELU) can be loaded,
    its CUDA kernel is used for CUDA inputs outside torch._dynamo tracing.
    Otherwise a one-time warning is logged and the PyTorch version is used.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()

        def _inverse_softplus(value: float) -> torch.Tensor:
            # log(exp(v) - 1) inverts softplus, so softplus(param) == value at
            # initialization (up to `dtype` rounding). Shape (1,).
            shifted = torch.exp(torch.tensor(value, dtype=dtype)) - 1
            return torch.log(shifted).unsqueeze(0)

        self.alpha_p = nn.Parameter(_inverse_softplus(alpha_p_init))
        # alpha_n is parameterized relative to beta (see _xielu_python).
        self.alpha_n = nn.Parameter(_inverse_softplus(alpha_n_init - beta))
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Host-side float copies handed to the CUDA kernel so no `.item()` is
        # needed at call time. Temporary until xIELU CUDA fully implemented.
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            # Presumably registers torch.classes.xielu as an import side effect.
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            status = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                status += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as dynamo_err:
                status += (
                    f" Could not enable torch._dynamo for xIELU ({dynamo_err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(status)
        except Exception as err:
            # Best-effort optional dependency: any failure means Python fallback.
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU (also used while torch._dynamo is compiling)."""
        scale_pos = F.softplus(self.alpha_p)
        scale_neg = self.beta + F.softplus(self.alpha_n)
        linear_term = self.beta * x
        pos_branch = scale_pos * x * x + linear_term
        neg_branch = (
            torch.expm1(torch.min(x, self.eps)) - x
        ) * scale_neg + linear_term
        return torch.where(x > 0, pos_branch, neg_branch)

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # The CUDA kernel expects a 3-D tensor: prepend size-1 axes when there
        # are fewer dims, or fold all but the last dim into the first when more.
        for _ in range(3 - x.dim()):
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.shape[-1])
        if x.shape != original_shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        cuda_kernel_usable = self._xielu_cuda_obj is not None and input.is_cuda
        if cuda_kernel_usable and not torch._dynamo.is_compiling():
            return self._xielu_cuda_fn(input)
        if cuda_kernel_usable:
            # Only reached while torch._dynamo is tracing.
            logger.warning_once(
                "torch._dynamo is compiling, using Python version of xIELU."
            )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)

661

662
663
664
665
666
667
668
669
670
class ScaledActivation(nn.Module):
    """Wrap an activation and divide its output by learnable per-channel scales.

    This is used for some quantization methods like AWQ. When
    ``input_is_parallel`` is true, ``scales`` holds only this tensor-parallel
    rank's slice of the ``intermediate_size`` channels.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        local_size = (
            divide(intermediate_size, get_tensor_model_parallel_world_size())
            if input_is_parallel
            else intermediate_size
        )
        dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.scales = nn.Parameter(torch.empty(local_size, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of ``loaded_weight`` into ``param``."""
        dest = param.data
        if self.input_is_parallel:
            shard_len = dest.shape[0]
            offset = get_tensor_model_parallel_rank() * shard_len
            loaded_weight = loaded_weight.narrow(0, offset, shard_len)
        assert dest.shape == loaded_weight.shape
        dest.copy_(loaded_weight)

703

704
705
# Name -> zero-argument factory for plain activations. The lambdas bind module
# globals at call time (and `_get_gelu_pytorch_tanh` is defined further down),
# so they must stay lambdas rather than bare class references.
_ACTIVATION_REGISTRY: LazyDict[nn.Module] = LazyDict(
    {
        # GELU family.
        "gelu": lambda: GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(),
        # ReLU family.
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        # Others.
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
719
720


721
722
723
724
725
726
727
728
729
730
731
732
def _get_gelu_pytorch_tanh() -> nn.Module:
    """Return ``nn.GELU`` with tanh approximation, or exact GELU on ROCm."""
    if not current_platform.is_rocm():
        return nn.GELU(approximate="tanh")
    # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
    logger.warning_once(
        "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
        "Falling back to GELU(approximate='none')."
    )
    return nn.GELU(approximate="none")


733
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Fully qualified names such as ``torch.nn.modules.activation.ReLU`` are
    reduced to their final component; ``...Identity`` maps to ``nn.Identity``.

    NOTE(review): if LazyDict caches constructed values, callers requesting
    the same name share one module instance. That matters for stateful
    activations such as XIELU (which owns nn.Parameters) — TODO confirm.

    Raises:
        ValueError: If the name is not registered.
    """
    name = act_fn_name.lower()
    if name.startswith("torch.nn.modules."):
        name = name.rsplit(".", 1)[-1]
        if name == "identity":
            return nn.Identity()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    return _ACTIVATION_REGISTRY[name]
747
748


749
# Name -> zero-argument factory for gated "activation(x[:d]) * x[d:]" ops.
# "geglu" is an alias for "gelu" (separate key, separately constructed).
_ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        # Interleaved gate/up layout; see SwigluOAIAndMul.
        "swigluoai": lambda: SwigluOAIAndMul(),
    }
)
757
758


759
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: If the name is not registered.
    """
    normalized = act_fn_name.lower()
    if normalized in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[normalized]
    raise ValueError(f"Activation function {normalized!r} is not supported.")