activation.py 25.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import current_platform
csy0225's avatar
csy0225 committed
20
from vllm.triton_utils import tl, triton
21
from vllm.utils.collection_utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
22
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
23

24
25
# Module-level logger, created by vLLM's init_logger and keyed on this
# module's import name.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
26

csy0225's avatar
csy0225 committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Clamped SwiGLU over one (row, column-block) tile.

    Grid axis 0 selects the row, grid axis 1 selects a block of
    ``BLOCK_SIZE`` columns. For each column ``c < d`` of the selected row
    the kernel reads ``gate = x[row, c]`` and ``up = x[row, c + d]``,
    computes ``min(silu(gate), limit) * clamp(up, -limit, limit)`` in
    float32, and writes it to ``o[row, c]``.

    Row addressing uses ``x_stride`` / ``o_stride``; columns are addressed
    with unit step.
    """
    # Row index is widened to int64 before it is multiplied by the stride.
    row = tl.program_id(axis=0).to(tl.int64)
    col_block = tl.program_id(axis=1)

    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < d  # the last column block may be partial

    dst_row = o_ptr + o_stride * row
    src_row = x_ptr + x_stride * row

    # Gate lives in columns [0, d), up in columns [d, 2d) of the same row.
    gate = tl.load(src_row + cols, mask=in_bounds).to(tl.float32)
    up = tl.load(src_row + cols + d, mask=in_bounds).to(tl.float32)

    # silu(gate) is capped from above only; up is clamped on both sides.
    gate = tl.minimum(tl.sigmoid(gate) * gate, limit)
    up = tl.minimum(tl.maximum(up, -limit), limit)

    # Cast back to the input element type before storing.
    out = (gate * up).to(x_ptr.dtype.element_ty)
    tl.store(dst_row + cols, out, mask=in_bounds)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Launch the clamped-SwiGLU Triton kernel on a 2D input.

    Args:
        output: Destination tensor; the kernel writes ``d`` values per row,
            addressed through ``output.stride(0)``.
        input: 2D tensor of shape ``(b, 2 * d)``; the first ``d`` columns are
            the gate and the last ``d`` columns are the up projection.
        limit: Clamp bound forwarded to the kernel.

    Raises:
        ValueError: If ``input`` is not 2D.
        AssertionError: If the last dimension of ``input`` is odd.
    """
    # Check the rank *before* unpacking the shape: unpacking first would make
    # the rank check unreachable and surface an opaque "values to unpack"
    # error instead of a message that names the actual problem.
    if input.ndim != 2:
        raise ValueError(
            "swiglustep_and_mul_triton expects a 2D input, got shape "
            f"{tuple(input.shape)}"
        )
    b, n = input.shape
    assert n % 2 == 0, f"last dimension must be even, got {n}"
    d = n // 2

    def grid(meta):
        # One program per (row, BLOCK_SIZE-wide column block).
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


78
# --8<-- [start:fatrelu_and_mul]
79
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    Splits the last dimension in half (d = x.shape[-1] // 2) and returns
    FATReLU(x[:d]) * x[d:], where FATReLU keeps values strictly above
    ``threshold`` and zeroes the rest.
    Used by openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Compiled kernel with signature op(out, x, threshold).
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # On CPU, route dispatch to the pure-PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch implementation."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x, self.threshold)
        return result
115
116


117
# --8<-- [start:silu_and_mul]
118
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation.

    Splits the last dimension in half (d = x.shape[-1] // 2) and returns
    silu(x[:d]) * x[d:].

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.silu_and_mul
            # Bound here but not called by any method of this class.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # On CPU, route dispatch to the pure-PyTorch path.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path; uses the "opt" op when VLLM_USE_OPT_OP is enabled."""
        if envs.VLLM_USE_OPT_OP:
            from vllm import _custom_ops as ops

            # This op returns its result instead of filling an output buffer.
            return ops.silu_and_mul_opt_lightop(x)

        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the bound op fills a preallocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

171

172
# --8<-- [start:mul_and_silu]
173
174
175
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant with the activation on the second half.

    Splits the last dimension in half (d = x.shape[-1] // 2) and returns
    x[:d] * silu(x[d:]).

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            # NOTE(review): this binds the silu_and_mul kernel while the class
            # applies silu to the *second* half; confirm the operand order
            # before wiring up forward_xpu.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # On CPU, route dispatch to the pure-PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        left, right = x[..., :half], x[..., half:]
        return left * F.silu(right)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:

212

213
# --8<-- [start:gelu_and_mul_sparse]
Robert Shaw's avatar
Robert Shaw committed
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparse gated GELU, as used in Gemma3n.

    With d = x.shape[-1] // 2, the first half of x is the gate projection and
    the second half is the up projection. The op computes:
        gate = gaussian_topk(x[:d])   # sparsity
        gate = gelu(gate)
        return gate * x[d:]
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # Quantile of the standard normal at the requested sparsity; used as
        # the number of standard deviations above the mean for the cutoff.
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        standard_normal = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = standard_normal.icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        row_mean = x.mean(dim=-1, keepdim=True)
        row_std = x.std(dim=-1, keepdim=True, unbiased=False)
        threshold = row_mean + row_std * self.std_multiplier
        # Everything at or below the per-row cutoff becomes zero.
        return nn.functional.relu(x - threshold)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate = self._gaussian_topk(x[..., :half])
        gate = F.gelu(gate, approximate=self.approximate)
        return gate * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No dedicated kernel; reuse the native path."""
        return self.forward_native(x)


272
# --8<-- [start:gelu_and_mul]
273
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation.

    Splits the last dimension in half (d = x.shape[-1] // 2) and returns
    GELU(x[:d]) * x[d:].

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # After validation, `approximate` is either "none" or "tanh".
        use_tanh = approximate == "tanh"

        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if use_tanh:
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
            else:
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt

        # NOTE(review): the XPU branch below is chained to the ROCm warning
        # check rather than to the platform dispatch above. Kept as-is to
        # preserve behavior; confirm whether it was meant to be an `elif` of
        # the outer platform check.
        if current_platform.is_rocm() and use_tanh:
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            if use_tanh:
                self.op = ipex_ops.gelu_tanh_and_mul
            else:
                self.op = ipex_ops.gelu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        mode = self.approximate
        if current_platform.is_rocm() and mode == "tanh":
            mode = "none"
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=mode) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path; uses the "opt" kernel when VLLM_USE_OPT_OP is enabled."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the bound op fills a preallocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        self.op(result, x)
        return result

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"
340

341

342
# --8<-- [start:swigluoai_and_mul]
343
344
345
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    """Clamped, interleaved SwiGLU variant (gpt-oss style).

    Gate and up values are interleaved along the last dimension: even
    indices are the gate, odd indices are the up projection. The native
    path computes ``(up + 1) * gate * sigmoid(alpha * gate)`` after clamping
    gate from above and up on both sides by ``limit``.
    """

    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # De-interleave: even positions -> gate, odd positions -> up.
        gate = x[..., ::2].clamp(min=None, max=self.limit)
        up = x[..., 1::2].clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * glu

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        half = x.shape[-1] // 2
        result = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


csy0225's avatar
csy0225 committed
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        """Store the clamp bound.

        Raises:
            ValueError: If ``limit`` is None.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        # The gate is only capped from above; up is clamped on both sides.
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Triton implementation.

        The Triton launcher only accepts 2D tensors, so inputs with any other
        rank (e.g. (batch_size, seq_len, 2 * d)) are flattened to
        (rows, 2 * d) for the launch. The returned tensor keeps the input's
        leading dimensions.
        """
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if x.ndim == 2:
            swiglustep_and_mul_triton(out, x, self.limit)
        else:
            # `out` is freshly allocated, so `view` aliases its storage and
            # the kernel's writes land in the tensor we return. `reshape` on
            # the input may copy if its leading dims cannot be merged.
            swiglustep_and_mul_triton(
                out.view(-1, d), x.reshape(-1, x.shape[-1]), self.limit
            )
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


412
# --8<-- [start:gelu_new]
413
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """GELU using the tanh-based approximation with a cubic term."""

    # --8<-- [end:gelu_new]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the bound op returns its result directly."""
        return self.op(x)
438

439

440
# --8<-- [start:gelu_fast]
441
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """GELU tanh approximation with the scale constant precomputed."""

    # --8<-- [end:gelu_fast]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the bound op returns its result directly."""
        return self.op(x)
465

466

467
# --8<-- [start:quick_gelu]
468
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which writes into a preallocated output."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the bound op fills a preallocated output."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

499

500
# --8<-- [start:relu2]
501
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    relu^2 activation (https://arxiv.org/abs/2109.08668v2): relu(x) squared,
    applied elementwise.
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO : implement cuda kernels
        # Until then the CUDA path is the native path.
        return self.forward_native(x)


518
# --8<-- [start:xielu]
519
520
521
522
523
524
525
526
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    xIELU activation (https://arxiv.org/abs/2411.13010).

    If the optional nickjbrowning/XIELU package can be imported, its CUDA
    implementation is used for CUDA tensors outside of torch._dynamo
    compilation. Otherwise a warning is logged once and the pure-PyTorch
    implementation is used.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()

        def inverse_softplus(value: float) -> torch.Tensor:
            # log(exp(v) - 1), shaped (1,); softplus() of this recovers v.
            return torch.log(
                torch.exp(torch.tensor(value, dtype=dtype)) - 1
            ).unsqueeze(0)

        # Parameters are stored pre-softplus; alpha_n additionally has beta
        # subtracted because beta is added back when it is used.
        self.alpha_p = nn.Parameter(inverse_softplus(alpha_p_init))
        self.alpha_n = nn.Parameter(inverse_softplus(alpha_n_init - beta))
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        # Stays None unless the optional CUDA extension loads successfully.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            status = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                status += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                status += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(status)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU.

        x > 0:  alpha_p * x^2 + beta * x
        else:   (expm1(min(x, eps)) - x) * alpha_n + beta * x
        """
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        positive_branch = alpha_p * x * x + self.beta * x
        negative_branch = (
            torch.expm1(torch.min(x, self.eps)) - x
        ) * alpha_n + self.beta * x
        return torch.where(x > 0, positive_branch, negative_branch)

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        input_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if input_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                input_shape,
                x.shape,
            )
        activated = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Hand the result back in the caller's original shape.
        return activated.view(input_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA extension when possible, else the Python version."""
        cuda_kernel_usable = self._xielu_cuda_obj is not None and input.is_cuda
        if cuda_kernel_usable:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            logger.warning_once(
                "torch._dynamo is compiling, using Python version of xIELU."
            )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        """Same selection logic as forward_native."""
        return self.forward_native(input)

632

633
634
635
636
637
638
639
640
641
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
642
643
        intermediate_size: int,
        input_is_parallel: bool = True,
644
        params_dtype: torch.dtype | None = None,
645
646
647
    ):
        super().__init__()
        self.act = act_module
648
        self.input_is_parallel = input_is_parallel
649
650
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
651
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
652
653
654
655
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
656
        self.scales = nn.Parameter(
657
658
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
659
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
660

661
    def forward(self, x: torch.Tensor) -> torch.Tensor:
662
663
        return self.act(x) / self.scales

664
665
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
666
667
668
669
670
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
671
672
673
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

674

675
676
677
678
679
def _make_gelu_pytorch_tanh() -> nn.Module:
    """Build the "gelu_pytorch_tanh" activation.

    On ROCm a warning is logged once and the exact (non-tanh) GELU is
    returned instead; elsewhere the tanh approximation is used.
    """
    if not current_platform.is_rocm():
        return nn.GELU(approximate="tanh")
    # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
    logger.warning_once(
        "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
        "Falling back to GELU(approximate='none')."
    )
    return nn.GELU(approximate="none")


# Maps an activation name to a zero-argument factory that builds the module.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: _make_gelu_pytorch_tanh(),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
699
700


701
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Names of the form ``torch.nn.modules.<...>.<Name>`` are reduced to their
    last dotted component; ``identity`` maps to ``nn.Identity``.

    Raises:
        ValueError: If the resolved name is not in the registry.
    """
    key = act_fn_name.lower()

    if key.startswith("torch.nn.modules."):
        # Keep only the trailing class name of the dotted path.
        key = key.rsplit(".", 1)[-1]
        if key == "identity":
            return nn.Identity()

    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]

    raise ValueError(f"Activation function {key!r} is not supported.")
715
716


717
718
719
720
721
722
723
724
# Maps an activation name to a factory for the corresponding fused
# "activation-and-multiply" module. "gelu" and "geglu" both build GeluAndMul
# with its default arguments.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        # Unlike the entries above, this factory forwards any arguments it is
        # given to the constructor.
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
725
726


727
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul module (e.g. SiluAndMul) by
    case-insensitive name.

    Raises:
        ValueError: If the name is not in the registry.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")