sparsemax.py 4.93 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from typing import Tuple

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Compute sparsemax for one row per program.

    The program indexed by ``tl.program_id(0)`` reads one row of ``x`` and the
    matching row of ``sorted_x``, derives the sparsemax threshold ``tau`` from
    the sorted row, and writes ``max(x - tau, 0)`` to the same row of ``o``.

    Args:
        x_ptr, x_stride_row: input rows and their row stride.
        sorted_x_ptr, sorted_x_stride_row: the same rows sorted in DESCENDING
            order (assumed to be prepared by the caller; the support test
            below is only meaningful on a descending sort).
        o_ptr, o_stride_row: output rows and their row stride.
        n_cols: number of valid columns per row.
        BLOCK_SIZE: lanes per program. The row is handled in a single tile
            with no loop, so this assumes ``BLOCK_SIZE >= n_cols``; columns at
            or beyond ``BLOCK_SIZE`` would be neither read nor written.
        num_warps: not referenced in the body (the caller also passes it as a
            launch option).
    """
    pid_row = tl.program_id(0)
    ptr_x_data_row = x_ptr + pid_row * x_stride_row
    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
    ptr_output_row = o_ptr + pid_row * o_stride_row

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    # Padding lanes load as -inf so they can never pass the support test below.
    # The cache modifiers here and below are performance hints only.
    z_sorted_block = tl.load(
        ptr_sorted_x_data_row + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".cg",
    ).to(tl.float32)

    # Zero the padding before the prefix sum so -inf does not leak into cssv.
    z_valid = tl.where(mask, z_sorted_block, 0.0)
    cssv = tl.cumsum(z_valid, 0)

    # Candidate thresholds: t_vec[j] = (z_0 + ... + z_j - 1) / (j + 1).
    r = (offs + 1).to(tl.float32)
    t_vec = (cssv - 1.0) / r

    # Sparsemax support condition on the descending sort: z_j > t_vec[j].
    support = (z_sorted_block > t_vec) & mask

    # k = support size, clamped to >= 1 so the division below is defined
    # even if no lane passes the test.
    k_int = tl.sum(support.to(tl.int32), 0)
    k_clamped_int = tl.maximum(k_int, 1)
    k = k_clamped_int.to(tl.float32)

    # Sum of the sorted values inside the support.
    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)

    # tau = (s - 1) / k: subtracting it from the k support values leaves a total of 1.
    tau = (s - 1.0) / k

    # Masked lanes are never stored, so their fill value does not matter.
    x_block = tl.load(
        ptr_x_data_row + offs,
        mask=mask,
        other=0.0,
        cache_modifier=".cg",
    ).to(tl.float32)

    y = tl.maximum(x_block - tau, 0.0)

    # Cast back from fp32 to the output buffer's element type on store.
    tl.store(
        ptr_output_row + offs,
        y.to(ptr_output_row.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )


@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Sparsemax backward for one row per program.

    With support ``S = {i : o_i > 0}``, writes
    ``gi_i = go_i - mean_{j in S}(go_j)`` for ``i`` in ``S`` and ``0`` elsewhere.

    Args:
        o_ptr: forward output rows; only the ``o > 0`` pattern is used.
        go_ptr: upstream gradient rows.
        gi_ptr: gradient-input rows (written).
        stride: row stride used for ALL three buffers, so this assumes ``o``,
            ``go`` and ``gi`` share the same row layout.
        n_cols: number of valid columns per row.
        BLOCK_SIZE: tile width. The row is walked in
            ``cdiv(n_cols, BLOCK_SIZE)`` tiles, so one tile need not cover it.
        num_warps: not referenced in the body (the caller also passes it as a
            launch option).
    """
    row = tl.program_id(0)
    o_row = o_ptr + row * stride
    go_row = go_ptr + row * stride
    gi_row = gi_ptr + row * stride

    offs = tl.arange(0, BLOCK_SIZE)

    # Scalar fp32 accumulators for the support size and the support gradient sum.
    supp_cnt = tl.zeros((), tl.float32)
    go_sum = tl.zeros((), tl.float32)

    # Pass 1: accumulate sum(go) over the support and the support size.
    # Masked lanes load o = 0.0, so they are excluded from the support.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
        supp_cnt += tl.sum(supp.to(tl.float32))

    # Pass 2: re-read the row and write go - mean on the support, 0 elsewhere.
    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        gi_val = tl.where(
            supp,
            # tl.maximum(..., 1e-6) avoids 0/0: with an empty support go_sum is
            # also 0, so the mean evaluates to 0. The mean is rounded to gi's
            # element type and back to fp32 before subtracting.
            # NOTE(review): presumably done to match a reference computed in the
            # output dtype; confirm the precision loss is intended.
            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
            0.0,
        )
        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")


def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply sparsemax to ``x`` along ``dim`` using the Triton forward kernel.

    The reduction dim is swapped to the last position and the tensor is
    flattened to a contiguous ``(n_rows, n_cols)`` matrix, so each kernel
    program handles one row.

    Args:
        x: Input tensor.
        dim: Dimension to normalize over; negative values count from the end.

    Returns:
        ``(y, out_flat)``. ``y`` has the shape of ``x`` and is a transposed
        view of ``out_flat``, so it is generally non-contiguous. ``out_flat``
        is the contiguous ``(n_rows, n_cols)`` result, kept for the backward
        pass.
    """
    if dim < 0:
        dim += x.dim()
    x_sw = x.transpose(dim, -1).contiguous()
    n_cols = x_sw.size(-1)
    # Product of the leading dims. Unlike ``numel() // n_cols`` this does not
    # raise ZeroDivisionError when the reduced dim is empty.
    n_rows = x_sw.shape[:-1].numel()
    x_flat = x_sw.view(n_rows, n_cols)
    out_flat = torch.empty_like(x_flat)

    # Empty input has nothing to compute: skip the sort, block-size selection
    # and kernel launch.
    if x_flat.numel() > 0:
        # The kernel expects each row sorted in descending order (fp32 copy).
        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        _sparsemax_forward_kernel[(n_rows,)](
            x_flat,
            x_flat.stride(0),
            x_sorted_flat,
            x_sorted_flat.stride(0),
            out_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    y = out_flat.view_as(x_sw).transpose(dim, -1)
    return y, out_flat


def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """Gradient of sparsemax w.r.t. its input, via the Triton backward kernel.

    Args:
        grad_out: Upstream gradient, expected to have the forward output's shape.
        out_flat: Forward output saved for backward, as an ``(n_rows, n_cols)``
            matrix. Its row stride is passed to the kernel as the stride of all
            three buffers, so it must share the flattened gradient's layout.
        dim: Dimension sparsemax was applied over (negative values allowed).

    Returns:
        Gradient with the shape of ``grad_out``. It is a transposed view, so it
        is generally non-contiguous.
    """
    grad_sw = grad_out.transpose(dim, -1).contiguous()
    n_cols = grad_sw.size(-1)
    # Product of the leading dims. Unlike ``numel() // n_cols`` this does not
    # raise ZeroDivisionError when the reduced dim is empty.
    n_rows = grad_sw.shape[:-1].numel()
    go_flat = grad_sw.view(n_rows, n_cols)
    dx_flat = torch.empty_like(go_flat)

    # Empty input has nothing to compute: skip block-size selection and launch.
    if go_flat.numel() > 0:
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        _sparsemax_backward_kernel[(n_rows,)](
            out_flat,
            go_flat,
            dx_flat,
            # One stride for all three buffers: go_flat and dx_flat are
            # contiguous (n_rows, n_cols) matrices built just above.
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return dx_flat.view_as(grad_sw).transpose(dim, -1)


class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd function wiring the Triton sparsemax forward and backward.

    The forward pass saves only the flattened output. The backward pass needs
    nothing else because the sparsemax gradient depends only on which output
    entries are positive.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Return sparsemax of ``x`` along ``dim``; stash state for backward."""
        result, flat_result = _sparsemax_forward(x, dim)
        ctx.dim = dim
        ctx.save_for_backward(flat_result)
        return result

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Return the input gradient; ``dim`` is a non-tensor arg, so its grad is None."""
        flat_result = ctx.saved_tensors[0]
        grad_input = _sparsemax_backward(grad_out, flat_result, ctx.dim)
        return grad_input, None