test_sparsemax.py 3.61 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import pytest
import torch

from test.utils import assert_verbose_allclose
from test.utils import set_seed

from liger_kernel.transformers.functional import liger_sparsemax
from liger_kernel.transformers.sparsemax import LigerSparsemax
from liger_kernel.utils import infer_device

# Device that the test inputs and the LigerSparsemax module in this file are
# placed on; resolved once at import time via liger_kernel.utils.infer_device.
device = infer_device()


def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Pure-PyTorch sparsemax used as the ground truth for the Liger kernel.

    Sparsemax is the Euclidean projection of ``input_tensor`` onto the
    probability simplex along ``dim``. The steps are:
      1. sort the values in descending order;
      2. find the support size ``k``, the number of sorted positions ``j``
         (1-based) that satisfy ``1 + j * z_(j) > sum_{i<=j} z_(i)``;
      3. compute the threshold ``tau = (sum of the top-k values - 1) / k``;
      4. return ``max(input - tau, 0)``.

    Every op is a differentiable torch op (including ``torch.sort``), so
    autograd through this function also serves as the reference gradient.

    Args:
        input_tensor: input logits of any rank.
        dim: axis to normalize over; negative values count from the end.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    ndim = input_tensor.dim()
    axis = dim + ndim if dim < 0 else dim

    z_sorted = torch.sort(input_tensor, dim=axis, descending=True)[0]
    z_cumsum = z_sorted.cumsum(dim=axis)
    n = input_tensor.size(axis)

    # Ranks 1..n laid out along `axis` so they broadcast against z_sorted.
    broadcast_shape = [n if ax == axis else 1 for ax in range(ndim)]
    ranks = torch.arange(1, n + 1, device=input_tensor.device, dtype=input_tensor.dtype).view(broadcast_shape)

    # The support condition holds on a prefix of the sorted values, so its
    # count is the largest qualifying rank k.
    in_support = (1 + ranks * z_sorted) > z_cumsum
    support_size = in_support.sum(dim=axis, keepdim=True).clamp(min=1)
    tau = ((z_sorted * in_support).sum(dim=axis, keepdim=True) - 1) / support_size
    return (input_tensor - tau).clamp(min=0)


# Parametrization shared by the module and functional tests, so both Liger
# entry points are always exercised on exactly the same cases.
_SPARSEMAX_SHAPES = [
    (2, 128, 512),
    (5, 123, 123),
]
_SPARSEMAX_DIMS = [-1, 1]
_SPARSEMAX_DTYPES = [
    (torch.float32, 1e-5, 1e-5),
]


def _check_sparsemax_against_reference(sparsemax_fn, batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Compare ``sparsemax_fn`` against ``torch_sparsemax`` on a seeded random input.

    Checks, in order:
      * the forward output matches the reference;
      * both outputs sum to 1 along ``dim`` (10x looser tolerances);
      * the input gradients match for the same random upstream gradient.

    Args:
        sparsemax_fn: callable taking the input tensor and returning its
            sparsemax along ``dim`` (wraps the Liger module or functional API).
        batch_size, seq_len, features: shape of the random input.
        dim: axis sparsemax is applied over.
        dtype: dtype of the random input.
        atol, rtol: tolerances for the output and gradient comparisons.
    """
    set_seed(0)
    shape = (batch_size, seq_len, features)
    # These guards only fire if the parametrization is widened later.
    if not -len(shape) <= dim < len(shape):
        pytest.skip("invalid dim")
    if shape[dim] <= 1:
        pytest.skip("trivial dim")

    x = torch.randn(*shape, dtype=dtype, device=device)
    lx = x.clone().requires_grad_(True)
    tx = x.clone().requires_grad_(True)

    out_l = sparsemax_fn(lx)
    out_t = torch_sparsemax(tx, dim=dim)
    assert_verbose_allclose(out_l, out_t, atol=atol, rtol=rtol)

    # Every slice along `dim` must be a probability distribution.
    for out in (out_l, out_t):
        total = out.sum(dim=dim)
        assert_verbose_allclose(total, torch.ones_like(total), atol=atol * 10, rtol=rtol * 10)

    g = torch.randn_like(x)
    out_l.backward(g)
    out_t.backward(g)
    assert_verbose_allclose(lx.grad, tx.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("batch_size, seq_len, features", _SPARSEMAX_SHAPES)
@pytest.mark.parametrize("dim", _SPARSEMAX_DIMS)
@pytest.mark.parametrize("dtype, atol, rtol", _SPARSEMAX_DTYPES)
def test_liger_sparsemax_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Module API: ``LigerSparsemax`` matches the PyTorch reference (forward and backward)."""
    # Build the module lazily, only for non-skipped cases and after the input
    # has been drawn, matching the original construction order.
    _check_sparsemax_against_reference(
        lambda t: LigerSparsemax(dim=dim).to(device)(t),
        batch_size,
        seq_len,
        features,
        dim,
        dtype,
        atol,
        rtol,
    )


@pytest.mark.parametrize("batch_size, seq_len, features", _SPARSEMAX_SHAPES)
@pytest.mark.parametrize("dim", _SPARSEMAX_DIMS)
@pytest.mark.parametrize("dtype, atol, rtol", _SPARSEMAX_DTYPES)
def test_liger_sparsemax_functional_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Functional API: ``liger_sparsemax`` matches the PyTorch reference (forward and backward)."""
    _check_sparsemax_against_reference(
        lambda t: liger_sparsemax(t, dim=dim),
        batch_size,
        seq_len,
        features,
        dim,
        dtype,
        atol,
        rtol,
    )