softmax.py 256 Bytes
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn

from liger_kernel.ops import LigerSoftmaxFunction


class LigerSoftmax(nn.Module):
    """Module wrapper around ``LigerSoftmaxFunction``.

    The module registers no parameters or buffers. Each call to
    :meth:`forward` hands its input straight to
    ``LigerSoftmaxFunction.apply`` (imported from ``liger_kernel.ops``) and
    returns that result unchanged.

    NOTE(review): the softmax axis and the supported dtypes/devices are
    decided by ``LigerSoftmaxFunction`` and are not visible here. It is
    presumably a softmax over the last dimension; confirm against
    ``liger_kernel.ops``.
    """

    def __init__(self) -> None:
        # Nothing to configure; only the base nn.Module bookkeeping is needed.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Liger softmax to ``x``.

        Args:
            x: Input tensor, passed as-is to ``LigerSoftmaxFunction.apply``.

        Returns:
            The output of ``LigerSoftmaxFunction.apply(x)``. This is
            presumably a tensor with the same shape as ``x`` (TODO confirm).
        """
        # ``.apply`` is the entry point for invoking a torch.autograd.Function.
        result = LigerSoftmaxFunction.apply(x)
        return result