"""Multi-round test of the InfiniOp PagedAttentionPrefill operator.

Each test case runs several "rounds". In every round each sequence receives a
random number of new tokens, their K/V vectors are written into a paged cache,
and the operator output for the new queries is compared against a pure-PyTorch
reference that attends over the whole cached history with a causal mask.
"""

import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    TestWorkspace,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
    # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
    (16, 128, 128, 128, 8, 16, 4),
]

_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]

_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}

# Defaults; overwritten from the command line in the __main__ section below.
DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10


# ==============================================================================
#  Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
    """Minimal block allocator that tracks which cache blocks each request owns.

    Blocks are handed out in ascending index order from ``free_blocks`` and are
    never returned to the pool.
    """

    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        # Block indices not yet assigned to any request, lowest index first.
        self.free_blocks = list(range(num_blocks))
        # request_id -> ordered list of block indices owned by that request.
        self.request_to_blocks = {}
        # request_id -> number of tokens stored so far for that request.
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        """Reserve room for ``num_new_tokens`` more tokens of ``request_id``.

        Args:
            request_id: Key identifying the request; unseen keys start empty.
            num_new_tokens: Number of tokens appended in this step.

        Returns:
            A ``(block_table, total_len)`` tuple. ``block_table`` is the
            manager's own (live) list of block indices for the request, and
            ``total_len`` is the request's token count after this step.

        Raises:
            RuntimeError: If the pool has fewer free blocks than required. The
                free list is left untouched in that case.
        """
        blocks = self.request_to_blocks.setdefault(request_id, [])
        start_pos = self.request_to_len.setdefault(request_id, 0)

        new_total_len = start_pos + num_new_tokens
        # Ceiling division: a partially filled block still occupies a block.
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = needed_blocks - len(blocks)

        if added_blocks > len(self.free_blocks):
            raise RuntimeError(
                f"SimpleCacheManager is out of blocks: request {request_id!r} "
                f"needs {added_blocks} more, only {len(self.free_blocks)} free"
            )
        if added_blocks > 0:
            # Take the lowest-numbered free blocks in a single slice rather
            # than popping from the front of the list once per block.
            blocks.extend(self.free_blocks[:added_blocks])
            del self.free_blocks[:added_blocks]

        self.request_to_len[request_id] = new_total_len
        return blocks, new_total_len


def ref_paged_attention_multi_turn(
    query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
):
    """PyTorch reference for causal attention of new tokens over a paged cache.

    Args:
        query_new: Packed queries of all sequences, indexed as
            ``[token, head, dim]``.
        k_cache: Paged key cache indexed as ``[block, head, slot, dim]``.
        v_cache: Paged value cache with the same layout as ``k_cache``.
        block_tables: Per-sequence rows of block indices; only the leading
            ``ceil(total_len / block_size)`` entries of a row are read.
        seq_lens: Per-sequence TOTAL number of cached tokens, i.e. history
            plus the new tokens of this step.
        new_lens: Per-sequence number of new (query) tokens in this step.
        offset: Start index of each sequence inside ``query_new``. Its length
            is the number of sequences plus one.
        scale: Factor applied to the raw attention scores.

    Returns:
        A tensor shaped like ``query_new`` holding the attention output.
    """
    block_size = k_cache.shape[2]
    outputs = torch.zeros_like(query_new)

    def gather_tokens(cache, block_ids, num_tokens):
        # cache[block_ids] is [used_block, head, slot, dim]; bringing slot next
        # to used_block and flattening yields tokens in logical order, after
        # which the unused tail of the last block is dropped.
        picked = cache[block_ids].transpose(1, 2)
        flat = picked.reshape(
            block_ids.shape[0] * block_size, cache.shape[1], cache.shape[3]
        )
        return flat[:num_tokens]

    for i in range(len(offset) - 1):
        total_len = int(seq_lens[i])
        num_new = int(new_lens[i])
        history_len = total_len - num_new
        start = int(offset[i])

        used_blocks = (total_len + block_size - 1) // block_size
        block_ids = block_tables[i][:used_blocks].to(
            device=k_cache.device, dtype=torch.long
        )
        K = gather_tokens(k_cache, block_ids, total_len)
        V = gather_tokens(v_cache, block_ids, total_len)
        Q = query_new[start : start + num_new, :, :]

        scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale

        # Causal mask: new token q sits at logical position history_len + q
        # and may attend to every key position up to and including its own.
        key_pos = torch.arange(total_len, device=Q.device)
        last_visible = history_len + torch.arange(num_new, device=Q.device)
        hidden = key_pos.unsqueeze(0) > last_visible.unsqueeze(1)
        scores = scores.masked_fill(hidden.unsqueeze(0), float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, V)
        outputs[start : start + num_new, :, :] = out

    return outputs


# ==============================================================================
#  Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    num_heads,
    num_kv_heads,
    head_size,
    block_size,
    max_step_len,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run ``num_rounds`` prefill steps and check each against the reference.

    The K/V caches and the block allocator persist across rounds, so every
    round after the first attends over the tokens of all earlier rounds.

    Args:
        handle: Handle passed through to the descriptor-creation call.
        device: Device identifier used for tensors and for name lookup.
        num_seqs: Number of sequences processed in each round.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        head_size: Size of each head; also sets ``scale = head_size**-0.5``.
        block_size: Number of token slots per cache block.
        max_step_len: Upper bound (inclusive) on new tokens per sequence and
            round.
        num_rounds: Number of consecutive rounds to run.
        dtype: Element type of the query, output and cache tensors.
        sync: Optional callable invoked after the operator call.

    Raises:
        AssertionError: If the operator output is not close to the reference.
    """
    print(
        f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
        f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, "
        f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
    )

    # 1. Initialize persistent resources
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)
    scale = head_size**-0.5

    k_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    v_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    # Reference-side views of the caches; new K/V rows are written here first.
    k_cache_ref = k_cache.torch_tensor()
    v_cache_ref = v_cache.torch_tensor()

    # Multi-turn testing loop
    for r in range(num_rounds):
        # Prepare dynamic inputs for this round
        seq_lens_cpu = torch.randint(
            1, max_step_len + 1, (num_seqs,), dtype=torch.int64
        )

        q_total_tokens = seq_lens_cpu.sum().item()
        q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)

        cache_lens_list = []
        all_block_tables = []
        offset_list = []

        cur_offset = 0
        for i in range(num_seqs):
            offset_list.append(cur_offset)

            cur_new_len = seq_lens_cpu[i].item()
            table, cache_len = manager.allocate_slots(i, cur_new_len)
            cache_lens_list.append(cache_len)
            all_block_tables.append(table)

            # Simulated KV insertion
            k_new = torch.randn(cur_new_len, num_kv_heads, head_size)
            v_new = torch.randn(cur_new_len, num_kv_heads, head_size)

            q_val = torch.randn(cur_new_len, num_heads, head_size)
            q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val

            cur_offset = cur_offset + cur_new_len

            # New tokens continue right after the tokens of previous rounds.
            history_len = cache_len - cur_new_len
            for t in range(cur_new_len):
                logical_pos = history_len + t
                b_id = table[logical_pos // block_size]
                off = logical_pos % block_size
                k_cache_ref[b_id, :, off, :] = k_new[t]
                v_cache_ref[b_id, :, off, :] = v_new[t]

        # Closing entry so that offset has num_seqs + 1 elements.
        offset_list.append(cur_offset)

        # Push the updated reference caches to the tensors the operator reads.
        # NOTE(review): relies on the private `_torch_tensor` attribute,
        # presumably the same tensor `torch_tensor()` returns - confirm.
        k_cache.actual_tensor().copy_(k_cache._torch_tensor)
        v_cache.actual_tensor().copy_(v_cache._torch_tensor)

        # 2. Wrap tensors for Infiniop
        q_new = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out.actual_tensor().zero_()

        cache_lens = TestTensor.from_torch(
            torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device
        )
        seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device)
        offset = TestTensor.from_torch(
            torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device
        )

        # Block tables differ in length; pad with 0 to get a rectangular tensor.
        max_blocks = max(len(t) for t in all_block_tables)
        padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
        block_tables = TestTensor.from_torch(
            torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
        )

        # 3. Reference Calculation
        def torch_paged_attention_multi_turn():
            # The reference names its arguments differently from this test:
            # cache_lens (total tokens) is its `seq_lens`, and seq_lens (new
            # tokens of this round) is its `new_lens`.
            return ref_paged_attention_multi_turn(
                q_new.torch_tensor(),
                k_cache.torch_tensor(),
                v_cache.torch_tensor(),
                block_tables.torch_tensor(),
                cache_lens.torch_tensor(),
                seq_lens.torch_tensor(),
                offset.torch_tensor(),
                scale,
            )

        ans = torch_paged_attention_multi_turn()

        # 4. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
                handle,
                ctypes.byref(descriptor),
                out.descriptor,
                q_new.descriptor,
                k_cache.descriptor,
                v_cache.descriptor,
                block_tables.descriptor,
                cache_lens.descriptor,
                seq_lens.descriptor,
                offset.descriptor,
                None,  # alibi_slopes_desc
                scale,
            )
        )

        # A descriptor is created in every round, so it is also destroyed in
        # every round, including when validation or profiling raises.
        try:
            workspace_size = c_uint64(0)
            check_error(
                LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize(
                    descriptor, ctypes.byref(workspace_size)
                )
            )
            workspace = TestWorkspace(workspace_size.value, device)

            def lib_attn():
                check_error(
                    LIBINFINIOP.infiniopPagedAttentionPrefill(
                        descriptor,
                        workspace.data(),
                        workspace_size.value,
                        out.data(),
                        q_new.data(),
                        k_cache.data(),
                        v_cache.data(),
                        block_tables.data(),
                        cache_lens.data(),
                        seq_lens.data(),
                        offset.data(),
                        None,
                        None,
                    )
                )

            lib_attn()
            if sync:
                sync()

            # 5. Validation
            atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
            if DEBUG:
                debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)

            assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)

            # Profiling
            if PROFILE:
                profile_operation(
                    f"Torch_R{r}",
                    torch_paged_attention_multi_turn,
                    device,
                    NUM_PRERUN,
                    NUM_ITERATIONS,
                )
                profile_operation(
                    f" Lib_R{r}", lib_attn, device, NUM_PRERUN, NUM_ITERATIONS
                )
        finally:
            check_error(
                LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor)
            )


# ==============================================================================
#  Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()

    # Module-level assignments: these replace the defaults that test() reads.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")