import pytest
import torch
from sgl_kernel import lightning_attention_decode


def naive_lightning_attention_decode(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
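        # Decay the previous KV state by exp(-slope) and add this step's
        # outer product k_t^T v_t (shape [..., d, e]).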
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
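        # Read the state with the current query: o_t = q_t @ kv_t, in float32.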
        qkv = torch.einsum(
            "... n d, ... d e -> ... n e",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.cat(output, dim=-2)

    return output.to(original_dtype), kv


configs = [
    # (batch_size, num_heads, dim, embed_dim)
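    # dim is the q/k head dimension; embed_dim is the v / KV-state dimension.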
    (1, 8, 64, 64),
    (2, 8, 64, 64),
    (1, 32, 32, 64),
    (2, 32, 32, 64),
    (4, 32, 64, 64),
    (4, 32, 64, 64),
    (16, 64, 96, 96),
    (64, 64, 96, 96),
]

dtypes = [torch.float32, torch.float16, torch.bfloat16]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
    device = torch.device("cuda")

    q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
    past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)
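    # past_kv and slope default to float32, matching the float32 state
    # accumulation in the reference implementation.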

    ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)

    output = torch.empty_like(ref_output)
    new_kv = torch.empty_like(ref_new_kv)
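    # The kernel writes its results into the preallocated output and new_kv
    # buffers in place.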
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)

    # Loose tolerances to absorb fp16/bf16 accumulation differences against
    # the float32 reference.
    rtol = 1e-2
    atol = 1e-2

    torch.testing.assert_close(
        output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )

    torch.testing.assert_close(
        new_kv,
        ref_new_kv,
        rtol=rtol,
        atol=atol,
        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )


if __name__ == "__main__":
    pytest.main([__file__])