activation.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4

5
import math
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
16
from vllm.logger import init_logger
17
from vllm.model_executor.custom_op import CustomOp
18
from vllm.model_executor.utils import set_weight_attrs
19
from vllm.platforms import CpuArchEnum, current_platform
csy0225's avatar
csy0225 committed
20
from vllm.triton_utils import tl, triton
21
from vllm.utils.collection_utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
# Module-level logger; used for the `warning_once` messages emitted below.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
25

csy0225's avatar
csy0225 committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Triton kernel computing, per row i and column c < d:
#     out[i, c] = min(silu(x[i, c]), limit) * clamp(x[i, c + d], -limit, limit)
# i.e. the gate is the first half of each input row and the up projection is
# the second half. Grid axis 0 selects the row, grid axis 1 selects a
# BLOCK_SIZE-wide tile of the d output columns.
#
# Args:
#   o_ptr / o_stride: output base pointer and element stride between rows.
#   x_ptr / x_stride: input base pointer and element stride between rows.
#   limit: clamp bound (compile-time constant).
#   d: number of output columns (half the input row width).
#   BLOCK_SIZE: columns handled per program.
#
# NOTE(review): columns are addressed as `row_ptr + offsets`, so both tensors
# must have unit stride along their last dimension — the caller must ensure it.
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Row index is widened to int64, presumably so `stride * i` cannot
    # overflow 32-bit arithmetic on large tensors.
    i = tl.program_id(axis=0).to(tl.int64)
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Masks off the tail of the last tile when d is not a multiple of BLOCK_SIZE.
    mask = offsets < d

    # Gate lives in columns [0, d), up in [d, 2d); math is done in float32.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    # silu(gate) is only clamped from above; up is clamped on both sides.
    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)

    # Cast back to the input element type before storing.
    result = gate_clamped * up_clamped
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Compute the clamped SwiGLU ("swiglustep") activation into ``output``.

    Writes ``silu(gate).clamp(max=limit) * up.clamp(-limit, limit)`` where
    ``gate, up = input[..., :d], input[..., d:]`` and ``d = input.shape[-1] // 2``.
    Leading dimensions are flattened into a single row axis before the Triton
    kernel is launched, so both ``(num_tokens, 2 * d)`` and
    ``(batch_size, seq_len, 2 * d)`` inputs are accepted.

    Args:
        output: Destination tensor, written in place. It must be viewable as
            ``(rows, width)`` with at least as many rows as the input and
            ``width >= d``, with unit stride along its last dimension.
            Normally it has shape ``input.shape[:-1] + (d,)``.
        input: Source tensor whose (even-sized) last dimension holds the gate
            half followed by the up half.
        limit: Clamp bound applied to both halves.
    """
    assert input.ndim >= 1, "input must have at least one dimension"
    n = input.shape[-1]
    assert n % 2 == 0, f"last dimension of input must be even, got {n}"
    d = n // 2

    # The kernel addresses rows through a single row stride and assumes unit
    # stride along the last dimension, so present both tensors as 2-D.
    x_2d = input if input.ndim == 2 else input.reshape(-1, n)
    if x_2d.stride(-1) != 1:
        x_2d = x_2d.contiguous()
    # `view` (not `reshape`) so we never silently write into a temporary copy.
    out_2d = output if output.ndim == 2 else output.view(-1, output.shape[-1])
    rows = x_2d.shape[0]
    assert out_2d.shape[0] >= rows and out_2d.shape[1] >= d, (
        f"output of shape {tuple(output.shape)} is too small for input of "
        f"shape {tuple(input.shape)}"
    )
    assert out_2d.shape[1] == 1 or out_2d.stride(-1) == 1, (
        "output must have unit stride along its last dimension"
    )

    # Nothing to compute; skip launching an empty grid.
    if rows == 0 or d == 0:
        return

    def grid(meta):
        return (rows, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        out_2d,
        out_2d.stride(0),
        x_2d,
        x_2d.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


77
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """FATReLU-gated multiply: ``FATReLU(x[..., :d]) * x[..., d:]``.

    ``d = x.shape[-1] // 2``. FATReLU keeps a gate value only when it is
    strictly greater than ``threshold`` and replaces it with 0 otherwise.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel invoked by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound here: route straight to PyTorch ops.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x, self.threshold)
        return result


116
# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU activation: ``silu(x[..., :d]) * x[..., d:]``.

    ``d = x.shape[-1] // 2``: the first half of the last dimension is passed
    through SiLU and multiplied element-wise with the second half.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        has_fused_kernel = current_platform.is_cuda_alike() or current_platform.is_xpu()
        if has_fused_kernel:
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_cpu():
            # CPU always takes the PyTorch path.
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same fused op selected in __init__.
        return self.forward_cuda(x)

153

154
# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant with the gate in the second half: ``x[..., :d] * silu(x[..., d:])``.

    ``d = x.shape[-1] // 2``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        has_fused_kernel = current_platform.is_cuda_alike() or current_platform.is_xpu()
        if has_fused_kernel:
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            # CPU always takes the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        up, gate = x[..., :half], x[..., half:]
        return up * F.silu(gate)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same fused op selected in __init__.
        return self.forward_cuda(x)


191
# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Here the gate is ``x[..., :d]`` and the up projection is ``x[..., d:]``
    with ``d = x.shape[-1] // 2``.
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Configure GELU mode and the Gaussian top-k cutoff.

        Args:
            activation_sparsity: Target fraction of gate values to zero. This
                is exact only if the gate values are roughly Gaussian along the
                last dimension. Must lie in the open interval (0, 1).
            approximate: GELU approximation, ``"none"`` or ``"tanh"``. On ROCm,
                ``"tanh"`` is replaced by ``"none"``.

        Raises:
            ValueError: If ``approximate`` is unknown, or if
                ``activation_sparsity`` is 0.0 or not in (0, 1).
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] Pytorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Fallback to 'none' approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        # The normal icdf is only finite on (0, 1). At 1.0 the cutoff is
        # infinite, giving an all-zero gate (or NaN when std == 0). Outside the
        # interval it is NaN. Either way forward() silently produces garbage.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}."
            )
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # Standard-normal quantile of the target sparsity. It is used as the
        # number of standard deviations above the mean for the cutoff.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution.

        Values at or below ``mean + std * std_multiplier`` (computed over the
        last dimension) become 0. The rest are shifted down by that cutoff.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel; the PyTorch implementation is used on CUDA too.
        return self.forward_native(x)


250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# --8<-- [start:gelu]
@CustomOp.register("gelu")
class GELU(CustomOp):
    """Exact GELU, i.e. ``F.gelu(x, approximate="none")``.

    On ARM CPUs, if the ``activation_lut_bf16`` custom op is available,
    contiguous bfloat16 inputs on the CPU path use that op instead.
    """

    # --8<-- [end:gelu]

    def __init__(self):
        super().__init__()
        self.op = None
        is_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
        if is_arm and hasattr(torch.ops._C, "activation_lut_bf16"):
            self.op = torch.ops._C.activation_lut_bf16

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x, approximate="none")

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        can_use_kernel = (
            self.op is not None
            and x.dtype == torch.bfloat16
            and x.is_contiguous()
        )
        if not can_use_kernel:
            return self.forward_native(x)
        result = torch.empty_like(x)
        self.op(result, x, "gelu")
        return result

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


278
# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation: ``GELU(x[..., :d]) * x[..., d:]``.

    ``d = x.shape[-1] // 2``. ``approximate`` selects exact (``"none"``) or
    tanh-approximated (``"tanh"``) GELU for the gate half.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        has_custom_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_custom_kernel:
            # `approximate` was validated above: it is "none" or "tanh".
            self.op = (
                torch.ops._C.gelu_and_mul
                if approximate == "none"
                else torch.ops._C.gelu_tanh_and_mul
            )
        if current_platform.is_rocm() and approximate == "tanh":
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable "
                "with torch.compile. For native implementation, fallback to 'none' "
                "approximation. The custom kernel implementation is unaffected."
            )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        force_exact = current_platform.is_rocm() and self.approximate == "tanh"
        approximate = "none" if force_exact else self.approximate
        half = x.shape[-1] // 2
        return F.gelu(x[..., :half], approximate=approximate) * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same custom op selected in __init__.
        return self.forward_cuda(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

335

336
# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    """GPT-OSS style clamped SwiGLU.

    In the native path, gate and up are interleaved along the last dimension
    (even indices are the gate, odd indices the up projection). The result is
    ``(up.clamp(-limit, limit) + 1) * g * sigmoid(alpha * g)`` with
    ``g = gate.clamp(max=limit)``.
    """

    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        bound = self.limit
        gate = x[..., ::2].clamp(min=None, max=bound)
        up = x[..., 1::2].clamp(min=-bound, max=bound)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * glu

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half), dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(result, x, self.alpha, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"alpha={self.alpha!r}, limit={self.limit!r}"


csy0225's avatar
csy0225 committed
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:swiglustep_and_mul]

    def __init__(self, limit: float = 7.0):
        """
        Args:
            limit: Upper clamp for silu(gate) and symmetric clamp for up.

        Raises:
            ValueError: If ``limit`` is None.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if x.numel() == 0:
            # Nothing to compute; avoid launching an empty Triton grid.
            return out
        # The Triton kernel addresses rows with a single row stride and assumes
        # unit stride along the last dimension. Collapse any leading dims (e.g.
        # batch_size, seq_len) into one row axis so 3-D inputs work as well.
        x_2d = x.reshape(-1, x.shape[-1])
        if x_2d.stride(-1) != 1:
            x_2d = x_2d.contiguous()
        # `out` is freshly allocated (contiguous), so this view aliases it.
        swiglustep_and_mul_triton(out.view(-1, d), x_2d, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


406
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    # --8<-- [end:gelu_new]

    def __init__(self):
        super().__init__()
        has_custom_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_custom_kernel:
            self.op = torch.ops._C.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = x + 0.044715 * torch.pow(x, 3.0)
        return 0.5 * x * (1.0 + torch.tanh(scale * inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same custom op selected in __init__.
        return self.forward_cuda(x)

433

434
# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-approximated GELU with a hard-coded ``sqrt(2 / pi) ~= 0.7978845608``:
    ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))``.
    """

    # --8<-- [end:gelu_fast]

    def __init__(self):
        super().__init__()
        has_custom_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_custom_kernel:
            self.op = torch.ops._C.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        arg = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(arg))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same custom op selected in __init__.
        return self.forward_cuda(x)

460

461
# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``."""

    # --8<-- [end:quick_gelu]

    def __init__(self):
        super().__init__()
        has_custom_kernel = (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        )
        if has_custom_kernel:
            self.op = torch.ops._C.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # XPU uses the same custom op selected in __init__.
        return self.forward_cuda(x)

488

489
# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO : implement cuda kernels
        return self.forward_native(x)


507
# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010

    If the nickjbrowning/XIELU package is installed, its CUDA implementation
    is used for CUDA inputs outside of torch.compile tracing. Otherwise a
    single warning is emitted and the Python implementation is used.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        """
        Args:
            alpha_p_init: Initial value of ``softplus(alpha_p)``, the
                quadratic coefficient of the positive branch.
            alpha_n_init: Initial value of ``beta + softplus(alpha_n)``, the
                coefficient of the negative branch.
            beta: Linear coefficient used in both branches (stored as a buffer).
            eps: Upper clamp on x inside ``expm1`` in the negative branch
                (stored as a buffer).
            dtype: dtype of the parameters and buffers.
            with_vector_loads: Flag passed through unchanged to the CUDA kernel.
        """
        super().__init__()
        # Parameters are stored in inverse-softplus space (log(exp(v) - 1)), so
        # softplus(alpha_p) == alpha_p_init and
        # beta + softplus(alpha_n) == alpha_n_init, up to `dtype` rounding.
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented: the CUDA call receives
        # Python floats precomputed here, so `.item()` is never run in forward.
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Deliberate best-effort: any failure to load the optional extension
            # leaves _xielu_cuda_obj as None and selects the Python path.
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch xIELU.

        For x > 0: ``alpha_p * x^2 + beta * x``.
        Otherwise: ``alpha_n * (expm1(min(x, eps)) - x) + beta * x``.
        """
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()

        Runs the extension kernel on a 3-D view of ``x`` and returns the
        result in ``x``'s original shape.
        """
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            # `reshape` rather than `view`: `view` raises for inputs whose
            # layout cannot be flattened without a copy (e.g. non-contiguous).
            x = x.reshape(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        """Use the CUDA kernel for CUDA inputs when available and not compiling."""
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)

621

622
623
624
625
626
627
628
629
630
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
631
632
        intermediate_size: int,
        input_is_parallel: bool = True,
633
        params_dtype: torch.dtype | None = None,
634
635
636
    ):
        super().__init__()
        self.act = act_module
637
        self.input_is_parallel = input_is_parallel
638
639
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
640
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
641
642
643
644
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
645
        self.scales = nn.Parameter(
646
647
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
648
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
649

650
    def forward(self, x: torch.Tensor) -> torch.Tensor:
651
652
        return self.act(x) / self.scales

653
654
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
655
656
657
658
659
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
660
661
662
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

663

664
665
def _make_gelu_pytorch_tanh() -> nn.Module:
    """Build the module registered as ``gelu_pytorch_tanh``.

    Returns ``nn.GELU(approximate="tanh")``. On ROCm it logs a warning and
    returns ``nn.GELU(approximate="none")`` instead.
    """
    if not current_platform.is_rocm():
        return nn.GELU(approximate="tanh")
    # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
    logger.warning_once(
        "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
        "Falling back to GELU(approximate='none')."
    )
    return nn.GELU(approximate="none")


# Lower-case activation names -> zero-argument module factories.
_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": _make_gelu_pytorch_tanh,
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)
688
689


690
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    Lookup is case-insensitive. Fully-qualified names under
    ``torch.nn.modules.`` are reduced to their last dotted component, and
    ``...Identity`` maps directly to ``nn.Identity()``.

    Raises:
        ValueError: If the (reduced) name is not in ``_ACTIVATION_REGISTRY``.
    """
    act_fn_name = act_fn_name.lower()

    torch_prefix = "torch.nn.modules."
    if act_fn_name.startswith(torch_prefix):
        act_fn_name = act_fn_name.rsplit(".", 1)[-1]
        if act_fn_name == "identity":
            return nn.Identity()

    if act_fn_name in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn_name]
    raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
704
705


706
707
708
709
710
711
712
713
def _make_gelu_and_mul() -> nn.Module:
    """Factory shared by the ``gelu`` and ``geglu`` entries below."""
    return GeluAndMul()


# Lower-case names -> factories for fused activation-and-mul modules.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": _make_gelu_and_mul,
        "silu": lambda: SiluAndMul(),
        "geglu": _make_gelu_and_mul,
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)
714
715


716
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    Args:
        act_fn_name: Case-insensitive key into ``_ACTIVATION_AND_MUL_REGISTRY``.

    Returns:
        The module the registry provides for that key.

    Raises:
        ValueError: If the name is not registered.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
    raise ValueError(f"Activation function {act_fn_name!r} is not supported.")