# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

7
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
8
from vllm.utils.torch_utils import set_random_seed
# Parameter grids shared by the parametrized tests in this module.
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Reference implementation of the lightning attention core algorithm.

    Unlike the main implementation, this processes every position
    sequentially instead of using parallelized Triton kernels, so it is slow
    but straightforward to check.

    Args:
        q: Queries, shape [B, H, S, D].
        k: Keys, shape [B, H, S, D].
        v: Values, shape [B, H, S, E].
        ed: Decay rates. Any shape holding either one rate per head
            (e.g. [H], [H, 1, 1] or [1, H, 1, 1]) or a single rate shared by
            all heads.
        block_size: Unused; accepted so the call mirrors the one made to the
            real implementation.
        kv_history: Optional initial KV state, shape [B, H, D, E]. It is
            cloned and never modified in place. When None, the state starts
            at zero.

    Returns:
        ``(output, kv)``, where ``output`` has shape [B, H, S, E] and ``kv``
        has shape [B, H, 2, D, E]: the final KV state duplicated along dim 2
        to match the shape returned by the real implementation.

    Raises:
        RuntimeError: if ``ed`` holds neither 1 nor H values.
    """
    B, H, S, D = q.shape
    E = v.shape[-1]
    dtype = q.dtype
    output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
    else:
        # clone() so the caller's history tensor is left untouched.
        kv_cache = kv_history.clone()

    # One decay factor per head. Flattening accepts every common layout of
    # `ed`; expand() broadcasts a single shared rate to all heads.
    decay = torch.exp(-ed).reshape(-1).expand(H)

    for b in range(B):
        for step in range(S):
            for h in range(H):
                # Rank-1 contribution of the current key/value pair.
                kv_outer = torch.outer(k[b, h, step], v[b, h, step])

                # Decay the old state, then add the new outer product
                # (same order as in the Triton kernel).
                kv_cache[b, h] = decay[h] * kv_cache[b, h] + kv_outer

                # The output at `step` reads the state that already
                # includes the current token.
                output[b, h, step] = torch.matmul(q[b, h, step], kv_cache[b, h])

    # The real implementation returns a KV tensor of shape [B, H, 2, D, E];
    # duplicate the final state along dim 2 so the shapes line up.
    kv_final = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    return output, torch.cat([kv_final, kv_final], dim=2)


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation of the linear-attention decode step.

    For every batch row ``b`` with ``slot_idx[b] != -1``:

    1. The KV state stored in cache slot ``slot_idx[b]`` is decayed by
       ``exp(-slope_rate[h])``.
    2. The new key/value outer product is added to it.
    3. The result is written back to that slot and multiplied by the query
       to produce the output.

    Rows with ``slot_idx[b] == -1`` are padding: their output stays zero and
    no cache slot is touched.

    Args:
        q: Queries, shape [B, H, 1, D].
        k: Keys, shape [B, H, 1, D].
        v: Values, shape [B, H, 1, E].
        kv_caches: Per-slot KV states, shape [num_slots, H, D, E]. Updated
            in place.
        slope_rate: One decay rate per head (H values).
        slot_idx: Cache slot for each batch row, or -1 for padding.

    Returns:
        Tensor of shape [B, H * E] holding the per-head outputs side by side.
    """
    B, H, _, D = q.shape
    E = v.shape[-1]
    output = torch.zeros(B, H * E, dtype=q.dtype, device=q.device)

    # Decay factors are computed once for all heads.
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions.
        if slot_id == -1:
            continue

        for h in range(H):
            q_bh = q[b, h, 0]
            k_bh = k[b, h, 0]
            v_bh = v[b, h, 0]

            # The cache is addressed by slot id, not by batch row; with
            # padding (or a non-identity mapping) the two differ.
            kv_old = kv_caches[slot_id, h]
            kv_new = torch.outer(k_bh, v_bh) + decay[h, 0, 0] * kv_old

            output[b, h * E : (h + 1) * E] = torch.matmul(q_bh, kv_new)
            kv_caches[slot_id, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Compare the Triton decode kernel with the reference implementation
    when every batch row maps to its own cache slot (no padding)."""
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    set_random_seed(42)

    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(
        batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
    )
    # Both implementations update the cache in place; give each its own copy.
    kv_caches_copy = kv_caches.clone()

    # Per-head decay rates 0.1, 0.2, ..., 0.1 * num_heads.
    slope_rate = 0.1 * torch.arange(
        1, num_heads + 1, dtype=torch.float32, device="cuda"
    )

    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(
        q, k, v, kv_caches, slope_rate, slot_idx
    )
    reference_output = reference_linear_decode(
        q, k, v, kv_caches_copy, slope_rate, slot_idx
    )

    torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Compare the Triton decode kernel with the reference implementation
    when one batch row is padding (slot -1) and rows map to different
    cache slots."""
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    set_random_seed(42)

    batch_size = 4
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(
        batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
    )
    # Both implementations update the cache in place; give each its own copy.
    kv_caches_copy = kv_caches.clone()

    # Per-head decay rates 0.1, 0.2, ..., 0.1 * num_heads.
    slope_rate = 0.1 * torch.arange(
        1, num_heads + 1, dtype=torch.float32, device="cuda"
    )

    # Row 2 is padding; row 3 uses cache slot 2, so slot 3 is never written.
    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(
        q, k, v, kv_caches, slope_rate, slot_idx
    )
    reference_output = reference_linear_decode(
        q, k, v, kv_caches_copy, slope_rate, slot_idx
    )

    atol, rtol = 1.5e-1, 1.5e-1

    # Output rows for padding positions are not compared.
    padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
    triton_masked = triton_output[padding_mask]
    reference_masked = reference_output[padding_mask]
    torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)

    # Cache entries are addressed by slot id, not by batch row, so compare
    # the whole cache. This covers every slot the kernel writes, and every
    # slot it must leave alone.
    torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=rtol, atol=atol)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
    """Compare ``lightning_attention`` with the sequential reference
    implementation, starting from a random KV history."""
    from vllm.model_executor.layers.lightning_attn import lightning_attention

    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    set_random_seed(42)

    base = 0.01
    q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)

    # Per-head decay rates 0.1, 0.2, ..., 0.1 * num_heads.
    ed = 0.1 * torch.arange(1, num_heads + 1, dtype=torch.float32, device="cuda")

    kv_history = base * torch.randn(
        batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
    )
    # Give each implementation its own copy of the history.
    kv_history_clone = kv_history.clone()

    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, 256, kv_history
    )
    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, 256, kv_history_clone
    )

    atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
    torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape