import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
    TestWorkspace,
)


# ==============================================================================
#  Reference Implementation
# ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
    """
    Reference implementation for the paged_caching operator.

    Scatters each token's K/V vectors into the paged cache pools at the
    position addressed by its flat slot index (slot // block_size selects
    the block, slot % block_size the offset within it).

    Args:
        key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
        value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
        key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
        value (torch.Tensor): Values, shape [ntok, nkvh, dh]
        slot_mapping (torch.Tensor): Flat slot index per token, shape [ntok]

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Cloned K and V pools with the
        tokens written in; the input pools are left untouched.
    """
    block_size = key_cache_pool.shape[2]

    # Operate on clones so the caller's pools are not modified, mimicking
    # the custom operator writing to its own output tensors.
    k_out = key_cache_pool.clone()
    v_out = value_cache_pool.clone()

    for tok_idx, slot in enumerate(slot_mapping.tolist()):
        blk, off = divmod(slot, block_size)
        k_out[blk, :, off, :] = key[tok_idx]
        v_out[blk, :, off, :] = value[tok_idx]

    return k_out, v_out


# ==============================================================================
#  Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
    # (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
    (1, 128, 8, 128, 16),
    (5, 512, 40, 128, 16),
    (16, 1024, 8, 64, 32),
    (10, 1024, 40, 64, 32),
]

# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types.
# Paged caching is a pure memory scatter (no arithmetic), so the copied
# values must match exactly; atol=0 with a tiny rtol enforces that.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 1e-5},
    InfiniDtype.BF16: {"atol": 0, "rtol": 1e-5},
    InfiniDtype.F32: {"atol": 0, "rtol": 1e-5},
}

# Global flags for controlling test behavior; overwritten from CLI
# arguments in the __main__ block below.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100


def test(
    handle,
    device,
    num_seqs,  # nreq
    max_seq_len,
    num_kv_heads,  # nkvh
    head_size,  # dh
    block_size,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """
    Build a random paged-caching workload, run the reference implementation
    and the library operator, and verify the resulting K/V cache pools match.

    Args:
        handle: infiniop library handle.
        device: device enum value the test runs on.
        num_seqs (int): Number of sequences in the batch (nreq).
        max_seq_len (int): Maximum context length per sequence.
        num_kv_heads (int): Number of KV heads (nkvh).
        head_size (int): Head dimension (dh).
        block_size (int): Tokens per cache block.
        dtype: Tensor data type (InfiniDtype member).
        sync: Optional callable that synchronizes the device.
    """
    print(
        f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
        f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, "
        f"head_size={head_size}, block_size={block_size}, dtype={InfiniDtypeNames[dtype]}"
    )

    num_blocks = 4096  # A reasonably large cache pool for testing

    # Create metadata: variable context lengths for each sequence in the batch
    context_lens_torch = torch.randint(
        1, max_seq_len + 1, (num_seqs,), dtype=torch.int64
    )
    ntok = torch.sum(context_lens_torch).item()

    # If ntok is 0 (all sequences have length 0), skip the test
    if ntok == 0:
        print("Skipping test case with ntok=0")
        return

    # Simulate the scheduler's behavior to create the slot_mapping:
    # each sequence receives a contiguous run of slots.
    slot_mapping_list = []
    current_slot = 0
    for length in context_lens_torch:
        start_slot = current_slot
        slot_mapping_list.extend(range(start_slot, start_slot + length.item()))
        current_slot += length.item()

    # Ensure we don't exceed the total number of slots in the cache
    assert current_slot <= num_blocks * block_size, (
        "Not enough blocks in the cache pool for this test case"
    )

    slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)

    # Create input tensors based on the calculated total tokens (ntok)
    k = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
    v = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
    slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I64, device)

    # The cache pools are the "output" tensors for this operator
    k_cache_pool = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    v_cache_pool = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )

    # Run reference implementation (cache pools first, then keys/values)
    k_cache_ref, v_cache_ref = ref_paged_caching(
        k_cache_pool.torch_tensor(),
        v_cache_pool.torch_tensor(),
        k.torch_tensor(),
        v.torch_tensor(),
        slot_mapping.torch_tensor(),
    )

    if sync:
        sync()

    # Create operator descriptor
    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
            handle,
            ctypes.byref(descriptor),
            k_cache_pool.descriptor,
            v_cache_pool.descriptor,
            k.descriptor,
            v.descriptor,
            slot_mapping.descriptor,
        )
    )

    # Get workspace size (likely 0 for this operator, but good practice to include)
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    # Invalidate descriptors to ensure kernel does not rely on them
    k.destroy_desc()
    v.destroy_desc()
    k_cache_pool.destroy_desc()
    v_cache_pool.destroy_desc()
    slot_mapping.destroy_desc()

    def lib_paged_caching():
        # Invoke the library operator; it writes into the cache pools in place.
        check_error(
            LIBINFINIOP.infiniopPagedCaching(
                descriptor,
                workspace.data(),
                workspace_size.value,
                k_cache_pool.data(),
                v_cache_pool.data(),
                k.data(),
                v.data(),
                slot_mapping.data(),
                None,
            )
        )

    # Execute the custom operator
    lib_paged_caching()

    if sync:
        sync()

    # Verify correctness
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        print("Verifying K cache...")
        debug(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
        print("Verifying V cache...")
        debug(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)

    assert torch.allclose(
        k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol
    )
    assert torch.allclose(
        v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol
    )

    # Profiling workflow
    if PROFILE:
        # fmt: off
        # BUGFIX: the previous code passed k/v where the cache pools belong;
        # ref_paged_caching takes the pools first, then keys/values.
        profile_operation("PyTorch", lambda: ref_paged_caching(
            k_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(),
            k.torch_tensor(), v.torch_tensor(),
            slot_mapping.torch_tensor()),
            device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lib_paged_caching, device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    # Clean up resources
    check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Pull the harness options for this run from the parsed CLI arguments.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Run every configured test case, for every dtype, on every device
    # the CLI selected.
    for dev in get_test_devices(args):
        test_operator(dev, test, _TEST_CASES_, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")