import ctypes
from ctypes import c_uint64

import torch

from libinfiniop import (
    LIBINFINIOP,
    InfiniDeviceNames,
    InfiniDtype,
    InfiniDtypeNames,
    TestTensor,
    TestWorkspace,
    check_error,
    debug,
    get_args,
    get_test_devices,
    get_tolerance,
    infiniopOperatorDescriptor_t,
    profile_operation,
    test_operator,
)

# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
    # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
    (16, 128, 128, 128, 8, 16, 4),
]

_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]

_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10


# ==============================================================================
# Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
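    """Minimal free-list block allocator mimicking a paged KV-cache manager.

    Tracks, per request id, the list of physical block ids and the number of
    tokens written so far.
    """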
    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
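        """Extend request_id's block table to cover num_new_tokens more tokens.

        Returns the (possibly grown) block table and the new total length.
        """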
        if request_id not in self.request_to_len:
            self.request_to_len[request_id] = 0
            self.request_to_blocks[request_id] = []

        start_pos = self.request_to_len[request_id]
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(self.request_to_blocks[request_id])

        for _ in range(added_blocks):
            self.request_to_blocks[request_id].append(self.free_blocks.pop(0))

        self.request_to_len[request_id] = new_total_len
        return self.request_to_blocks[request_id], new_total_len


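# Example: with block_size=8, a request that receives 5 tokens and then 6 more
# grows from ceil(5/8)=1 block to ceil(11/8)=2 blocks:
#   mgr = SimpleCacheManager(num_blocks=4, block_size=8)
#   mgr.allocate_slots(0, 5)  # -> ([0], 5)
#   mgr.allocate_slots(0, 6)  # -> ([0, 1], 11)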
def ref_paged_attention_multi_turn(
    query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
):
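    """Reference prefill attention over a paged KV cache.

    For each sequence, gathers its K/V token-by-token from the paged cache via
    its block table, then runs causally masked attention for the num_new query
    tokens appended this round; each new token also attends to every cached
    token. Assumes num_heads == num_kv_heads (true for all _TEST_CASES).
    """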
    block_size = k_cache.shape[2]
    outputs = torch.zeros_like(query_new)
    num_seqs = len(cum_seq_lens_q) - 1
    for i in range(num_seqs):
        num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
        total_len = seq_lens[i].item()
        cache_len = total_len - num_new

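        # Gather this sequence's K/V: logical position j lives in block
        # table[j // block_size] at offset j % block_size.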
        table = block_tables[i]
        keys_all, values_all = [], []
        for j in range(total_len):
            b_id = table[j // block_size].item()
            off = j % block_size
            keys_all.append(k_cache[b_id, :, off, :])
            values_all.append(v_cache[b_id, :, off, :])

        K = torch.stack(keys_all, dim=0)
        V = torch.stack(values_all, dim=0)
        Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]

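        # Q: (num_new, h, d), K: (total_len, h, d) -> scores: (h, num_new, total_len)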
        scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale

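        # Causal mask: new query token q_idx attends to all cache_len cached
        # tokens plus the new tokens up to and including itself.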
        mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
        for q_idx in range(num_new):
            mask[q_idx, : cache_len + q_idx + 1] = 0.0

        scores = scores + mask.unsqueeze(0)
        attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, V)

        outputs[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :] = out

    return outputs


# ==============================================================================
# Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    num_heads,
    num_kv_heads,
    head_size,
    block_size,
    max_step_len,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
        f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, "
        f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
    )

    # 1. Initialize persistent resources
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)
    scale = head_size**-0.5

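    # Paged KV cache layout: [num_blocks, num_kv_heads, block_size, head_size]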
    k_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    v_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )

    # Multi-turn testing loop
    for r in range(num_rounds):
        # Prepare dynamic inputs for this round
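        # (each sequence gains between 1 and max_step_len new tokens per round)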
        query_lens_cpu = torch.randint(
            1, max_step_len + 1, (num_seqs,), dtype=torch.int64
        )

        q_total_tokens = query_lens_cpu.sum().item()
        q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)

        seq_lens_list = []
        all_block_tables = []

        cum_seq_lens_q_list = []
        cum_q_lens = 0
        for i in range(num_seqs):
            cum_seq_lens_q_list.append(cum_q_lens)

            cur_q_len = query_lens_cpu[i].item()
            table, total_len = manager.allocate_slots(i, cur_q_len)
            cache_len = total_len - cur_q_len
            seq_lens_list.append(total_len)
            all_block_tables.append(table)

            # Simulated KV insertion
            k_new = torch.randn(cur_q_len, num_kv_heads, head_size)
            v_new = torch.randn(cur_q_len, num_kv_heads, head_size)
            q_val = torch.randn(cur_q_len, num_heads, head_size)
            q_packed_tensors[cum_q_lens : cum_q_lens + cur_q_len] = q_val

            cum_q_lens = cum_q_lens + cur_q_len

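            # Write the new K/V into the paged cache: logical position
            # cache_len + t maps to block table[pos // block_size], offset
            # pos % block_size.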
            for t in range(cur_q_len):
                logical_pos = cache_len + t
                b_id = table[logical_pos // block_size]
                off = logical_pos % block_size
                k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
                v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]

        cum_seq_lens_q_list.append(cum_q_lens)

        # Sync the reference (torch) cache contents into the operator tensors
        k_cache.actual_tensor().copy_(k_cache.torch_tensor())
        v_cache.actual_tensor().copy_(v_cache.torch_tensor())

        # 2. Wrap tensors for Infiniop
        q_new = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out.actual_tensor().zero_()

        seq_lens = TestTensor.from_torch(
            torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device
        )

        cum_seq_lens_q = TestTensor.from_torch(
            torch.tensor(cum_seq_lens_q_list, dtype=torch.int64),
            InfiniDtype.I64,
            device,
        )

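        # Pad every block table with 0s to the longest one so they form a
        # rectangular int64 tensor; the padding lies past each sequence's
        # valid blocks (bounded by seq_lens).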
        max_blocks = max(len(t) for t in all_block_tables)
        padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
        block_tables = TestTensor.from_torch(
            torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
        )

        # 3. Reference Calculation
        def torch_paged_attention_multi_turn():
            return ref_paged_attention_multi_turn(
                q_new.torch_tensor(),
                k_cache.torch_tensor(),
                v_cache.torch_tensor(),
                block_tables.torch_tensor(),
                seq_lens.torch_tensor(),
                cum_seq_lens_q.torch_tensor(),
                scale,
            )

        ans = torch_paged_attention_multi_turn()

        # 4. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
                handle,
                ctypes.byref(descriptor),
                out.descriptor,
                q_new.descriptor,
                k_cache.descriptor,
                v_cache.descriptor,
                block_tables.descriptor,
                seq_lens.descriptor,
                cum_seq_lens_q.descriptor,
                None,
                scale,
            )
        )

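        # Query the operator's required scratch size, then allocate it.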
        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        def lib_attn():
            check_error(
                LIBINFINIOP.infiniopPagedAttentionPrefill(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    out.data(),
                    q_new.data(),
                    k_cache.data(),
                    v_cache.data(),
                    block_tables.data(),
                    seq_lens.data(),
                    cum_seq_lens_q.data(),
                    None,
                    None,
                )
            )

        lib_attn()
        if sync:
            sync()

        # 5. Validation
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)

        assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)

        # Profiling
        if PROFILE:
            profile_operation(
                f"Torch_R{r}",
                lambda: torch_paged_attention_multi_turn(),
                device,
                NUM_PRERUN,
                NUM_ITERATIONS,
            )
            profile_operation(
                f"  Lib_R{r}", lambda: lib_attn(), device, NUM_PRERUN, NUM_ITERATIONS
            )

        check_error(
            LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor)
        )


# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()

    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")