sparsemax.py 393 Bytes
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

from liger_kernel.ops import LigerSparsemaxFunction


class LigerSparsemax(nn.Module):
    """``nn.Module`` wrapper that applies ``LigerSparsemaxFunction`` along one dimension.

    The module registers no parameters, buffers, or submodules. Its only state
    is ``dim``, which is forwarded unchanged to ``LigerSparsemaxFunction.apply``.

    Args:
        dim: Dimension handed to the sparsemax kernel (default ``-1``).
            NOTE(review): presumably the axis the sparsemax normalization runs
            over, analogous to ``torch.softmax(dim=...)`` — confirm against
            ``liger_kernel.ops``.
    """

    def __init__(self, dim: int = -1):
        super(LigerSparsemax, self).__init__()
        # Sole configuration value; read by forward() and extra_repr().
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the sparsemax kernel on ``x`` along ``self.dim``.

        Args:
            x: Input tensor. NOTE(review): any dtype/device/shape requirements
                are imposed by ``LigerSparsemaxFunction`` and are not visible here.

        Returns:
            The tensor produced by ``LigerSparsemaxFunction.apply(x, self.dim)``.
        """
        # ``.apply`` suggests a torch.autograd.Function, so gradients would flow
        # through the custom backward — TODO confirm in liger_kernel.ops.
        kernel = LigerSparsemaxFunction
        result = kernel.apply(x, self.dim)
        return result

    def extra_repr(self) -> str:
        """Return the configured ``dim`` for inclusion in this module's ``repr()``."""
        return f"dim={self.dim}"