activation.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
csy0225's avatar
csy0225 committed
20
from vllm.triton_utils import tl, triton
21
from vllm.utils.collection_utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
22
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
25
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
26

csy0225's avatar
csy0225 committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Fused clamped-SwiGLU kernel. Each program handles one BLOCK_SIZE-wide
    # tile of one row: it reads `gate` from columns [0, d) and `up` from
    # columns [d, 2 * d) of the x row and writes
    #     min(silu(gate), limit) * clamp(up, -limit, limit)
    # to the matching columns of the o row.
    #
    # o_stride / x_stride are the element offsets between consecutive rows.
    # Columns are addressed as `row_ptr + offsets`, i.e. adjacent columns are
    # read/written at adjacent element offsets.

    # Grid axis 0 selects the row, grid axis 1 selects the column tile.
    # The row index is widened to int64 before being multiplied by the stride
    # (presumably to keep the offset from overflowing 32 bits on large inputs).
    i = tl.program_id(axis=0).to(tl.int64)
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond column d are disabled for both the loads and the store.
    mask = offsets < d

    # All arithmetic is carried out in float32 regardless of the input dtype.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    # silu(gate) = gate * sigmoid(gate); only an upper clamp is applied to it,
    # while `up` is clamped on both sides.
    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)

    result = gate_clamped * up_clamped
    # NOTE(review): the result is cast to the *input* pointer's element type
    # even though it is stored through o_ptr; this is only equivalent when the
    # output tensor has the same dtype as the input.
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Launch the clamped-SwiGLU Triton kernel on a 2-D input.

    Writes ``min(silu(gate), limit) * clamp(up, -limit, limit)`` into
    ``output``, where ``gate`` is the first half and ``up`` the second half of
    the last dimension of ``input``.

    Args:
        output: Destination tensor; it is written in place. The kernel indexes
            it as ``(rows, d)`` using ``output.stride(0)`` as the row stride.
        input: Source tensor of shape ``(rows, 2 * d)``.
        limit: Clamp bound applied to the activated gate (upper bound only)
            and to ``up`` (both sides).

    Raises:
        ValueError: If ``input`` is not 2-D.
        AssertionError: If the last dimension of ``input`` is odd.
    """
    # The rank must be checked *before* unpacking the shape: unpacking a
    # non-2-D shape into two names fails on its own with an uninformative
    # "too many / not enough values to unpack" ValueError, which previously
    # made the rank check unreachable.
    if input.ndim != 2:
        raise ValueError(
            "swiglustep_and_mul_triton expects a 2-D input of shape "
            f"(rows, 2 * d), got a {input.ndim}-D tensor with shape "
            f"{tuple(input.shape)}."
        )
    b, n = input.shape
    assert n % 2 == 0
    d = n // 2

    def grid(meta):
        # One program per (row, BLOCK_SIZE-wide column tile).
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


78
# --8<-- [start:fatrelu_and_mul]
79
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension is split in half (d = x.shape[-1] // 2) and the result
    is FATReLU(x[..., :d]) * x[..., d:].
    Used by openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel used by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # On CPU always route through the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch implementation."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        # F.threshold keeps values strictly above the threshold, else 0.0.
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x, self.threshold)
        return result
115
116


117
# --8<-- [start:silu_and_mul]
118
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation.

    Computes x -> silu(x[:d]) * x[d:] with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
            # Both the regular and the "opt" kernel are bound up front;
            # forward_cuda chooses between them on every call.
            self.op = torch.ops._C.silu_and_mul
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel selected by ``envs.VLLM_USE_OPT_OP``."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

169

170
# --8<-- [start:mul_and_silu]
171
172
173
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant with the activation applied to the second half.

    Computes x -> x[:d] * silu(x[d:]) with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): this binds silu_and_mul, not a mul_and_silu op.
            # This class defines no forward_xpu that uses it (see TODO below);
            # confirm the intended op before adding one.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        left, right = x[..., :half], x[..., half:]
        return left * F.silu(right)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

210

211
# --8<-- [start:gelu_and_mul_sparse]
Robert Shaw's avatar
Robert Shaw committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()

        # --- GELU configuration ---
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        rocm_tanh = current_platform.is_rocm() and approximate == "tanh"
        if rocm_tanh:
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # --- Sparsity configuration ---
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        # Standard-normal quantile of the target sparsity; _gaussian_topk
        # uses it as the number of standard deviations above the mean at
        # which values start to survive.
        standard_normal = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = standard_normal.icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = torch.mean(x, dim=-1, keepdim=True)
        row_std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        threshold = row_mean + row_std * self.std_multiplier
        return nn.functional.relu(x - threshold)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate = self._gaussian_topk(x[..., :half])
        activated = F.gelu(gate, approximate=self.approximate)
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No dedicated kernel: delegates to the PyTorch implementation."""
        return self.forward_native(x)


270
# --8<-- [start:gelu_and_mul]
271
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation.

    Computes x -> GELU(x[:d]) * x[d:] with d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # From here on `approximate` is either "none" or "tanh".
        use_tanh = approximate == "tanh"

        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if use_tanh:
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
            else:
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt

        # This is a separate chain from the kernel binding above: the ROCm
        # check only warns, and the XPU branch binds the IPEX ops.
        if current_platform.is_rocm() and use_tanh:
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = (
                ipex_ops.gelu_tanh_and_mul if use_tanh else ipex_ops.gelu_and_mul
            )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        rocm_tanh = current_platform.is_rocm() and self.approximate == "tanh"
        mode = "none" if rocm_tanh else self.approximate
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=mode) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel selected by ``envs.VLLM_USE_OPT_OP``."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"
338

339

340
# --8<-- [start:swigluoai_and_mul]
341
342
343
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # Gate and up values are interleaved along the last dimension:
        # even positions hold the gate, odd positions hold the up projection.
        gate = x[..., ::2].clamp(min=None, max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        activated_gate = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * activated_gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel, writing into a freshly allocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


csy0225's avatar
csy0225 committed
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        """Store the clamp bound.

        Args:
            limit: Upper bound for the activated gate and symmetric bound for
                the up projection.

        Raises:
            ValueError: If ``limit`` is None.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Triton implementation; accepts any number of leading dimensions."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if x.ndim == 2:
            swiglustep_and_mul_triton(out, x, self.limit)
        else:
            # The Triton launcher unpacks input.shape into exactly
            # (rows, 2 * d), so e.g. a (batch_size, seq_len, 2 * d) input has
            # to be collapsed to 2-D first. `out` was just allocated and is
            # therefore contiguous, so the flattened view aliases its storage
            # and the kernel's writes land directly in `out`.
            swiglustep_and_mul_triton(
                out.view(-1, d), x.reshape(-1, x.shape[-1]), self.limit
            )
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


410
# --8<-- [start:gelu_new]
411
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    # --8<-- [end:gelu_new]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # tanh-based GELU approximation with coefficient sqrt(2 / pi).
        c = math.sqrt(2.0 / math.pi)
        inner = c * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which fills a caller-provided output tensor."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op and return its result directly."""
        return self.op(x)
436

437

438
# --8<-- [start:gelu_fast]
439
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    # --8<-- [end:gelu_fast]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # tanh-based GELU approximation with hard-coded coefficients.
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which fills a caller-provided output tensor."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op and return its result directly."""
        return self.op(x)
463

464

465
# --8<-- [start:quick_gelu]
466
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel, which fills a caller-provided output tensor."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX op, which fills a caller-provided output tensor."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

497

498
# --8<-- [start:relu2]
499
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    relu^2 activation, as introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegates to the PyTorch implementation."""
        # TODO : implement cuda kernels
        return self.forward_native(x)


516
# --8<-- [start:xielu]
517
518
519
520
521
522
523
524
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    xIELU activation function, introduced in https://arxiv.org/abs/2411.13010
    If the nickjbrowning/XIELU package can be imported, its xIELU CUDA
    implementation is used. Otherwise a warning is logged and the Python
    implementation is used.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()

        def inverse_softplus(value: float) -> torch.Tensor:
            # log(exp(value) - 1) as a 1-element tensor; the forward pass
            # applies softplus to recover the initial value.
            as_tensor = torch.tensor(value, dtype=dtype)
            return torch.log(torch.exp(as_tensor) - 1).unsqueeze(0)

        self.alpha_p = nn.Parameter(inverse_softplus(alpha_p_init))
        self.alpha_n = nn.Parameter(inverse_softplus(alpha_n_init - beta))
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None unless the optional CUDA extension is set up below.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            status = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                status += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as dynamo_err:
                # Fall back to calling the CUDA wrapper directly.
                status += (
                    f" Could not enable torch._dynamo for xIELU ({dynamo_err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(status)
        except Exception as import_err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(import_err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU."""
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        positive_branch = alpha_p * x * x + self.beta * x
        negative_branch = (
            torch.expm1(torch.min(x, self.eps)) - x
        ) * alpha_n + self.beta * x
        return torch.where(x > 0, positive_branch, negative_branch)

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        input_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if input_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                input_shape,
                x.shape,
            )
        activated = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Hand the result back in the caller's original shape.
        return activated.view(input_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA path when it is set up, the input is on CUDA and
        torch._dynamo is not compiling; otherwise use the Python path."""
        cuda_usable = self._xielu_cuda_obj is not None and input.is_cuda
        if not cuda_usable:
            return self._xielu_python(input)
        if torch._dynamo.is_compiling():
            logger.warning_once(
                "torch._dynamo is compiling, using Python version of xIELU."
            )
            return self._xielu_python(input)
        return self._xielu_cuda_fn(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Same dispatch as forward_native."""
        return self.forward_native(input)

630

631
632
633
634
635
636
637
638
639
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
640
641
        intermediate_size: int,
        input_is_parallel: bool = True,
642
        params_dtype: torch.dtype | None = None,
643
644
645
    ):
        super().__init__()
        self.act = act_module
646
        self.input_is_parallel = input_is_parallel
647
648
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
649
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
650
651
652
653
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
654
        self.scales = nn.Parameter(
655
656
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
657
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
658

659
    def forward(self, x: torch.Tensor) -> torch.Tensor:
660
661
        return self.act(x) / self.scales

662
663
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
664
665
666
667
668
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
669
670
671
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

672

673
674
675
676
677
def _make_gelu_pytorch_tanh() -> nn.Module:
    """Factory for "gelu_pytorch_tanh": tanh GELU, except on ROCm."""
    if current_platform.is_rocm():
        # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
        logger.warning_once(
            "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
            "Falling back to GELU(approximate='none')."
        )
        return nn.GELU(approximate="none")
    return nn.GELU(approximate="tanh")


# Maps a lower-case activation name to a zero-argument factory that builds
# the corresponding module.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: _make_gelu_pytorch_tanh(),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
697
698


699
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Names of the form ``torch.nn.modules.<...>.<Name>`` are reduced to their
    final component; ``identity`` in that form yields ``nn.Identity()``.

    Raises:
        ValueError: If the resulting name is not in the registry.
    """
    name = act_fn_name.lower()

    if name.startswith("torch.nn.modules."):
        # Keep only the trailing class name of the dotted path.
        name = name.rsplit(".", 1)[-1]
        if name == "identity":
            return nn.Identity()

    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")

    return _ACTIVATION_REGISTRY[name]
713
714


715
716
717
718
719
720
721
722
# Maps a lower-case activation name to a factory for the matching gated
# ("activation and multiply") module. "gelu" and "geglu" both build
# GeluAndMul with its default arguments.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        # Unlike the other entries this factory forwards any arguments it
        # receives to the constructor.
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
723
724


725
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: If the name is not in the registry.
    """
    name = act_fn_name.lower()
    if name in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[name]

    raise ValueError(f"Activation function {name!r} is not supported.")