Commit 38078981 authored by pengcheng888's avatar pengcheng888 Committed by PanZezhong
Browse files

issue/847-paged caching和attention添加infinicore的接口和测试

parent 298feac2
......@@ -5,6 +5,8 @@
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
......
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {

// Dispatcher-backed paged attention op: single-query attention over a
// block-paged KV cache.
class PagedAttention {
public:
    // (out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale)
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

    // Runs the implementation registered for out's device type.
    // Named the previously-anonymous trailing parameter `scale` to match the
    // schema and the free-function declarations below.
    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);

    // Per-device-type implementation registry.
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place variant: allocates and returns the output tensor.
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);

// In-place variant: writes the result into `out`.
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);

} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Dispatcher-backed paged KV-cache write op: scatters per-token k/v rows into
// block-paged key/value cache pools at positions given by slot_mapping.
class PagedCaching {
public:
    // (k, v, k_cache, v_cache, slot_mapping)
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
    // Runs the implementation registered for k's device type.
    static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
    // Per-device-type implementation registry.
    static common::OpDispatcher<schema> &dispatcher();
};
// In-place entry point: writes k/v into k_cache/v_cache per slot_mapping.
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
} // namespace infinicore::op
......@@ -44,6 +44,8 @@ from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
from infinicore.ops.paged_caching import paged_caching
from infinicore.ops.rearrange import rearrange
from infinicore.ops.squeeze import squeeze
from infinicore.ops.unsqueeze import unsqueeze
......@@ -115,6 +117,8 @@ __all__ = [
"from_list",
"from_numpy",
"from_torch",
"paged_caching",
"paged_attention",
"ones",
"strided_empty",
"strided_from_blob",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def paged_attention(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    block_tables: Tensor,
    seq_lens: Tensor,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Single-query attention over a block-paged KV cache.

    Writes into ``out`` when provided, otherwise allocates and returns a new
    tensor wrapping the native result.
    """
    slopes = None if alibi_slopes is None else alibi_slopes._underlying
    if out is not None:
        # In-place variant: the native op fills `out` directly.
        _infinicore.paged_attention_(
            out._underlying,
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            block_tables._underlying,
            seq_lens._underlying,
            slopes,
            scale,
        )
        return out
    # Out-of-place variant: wrap the freshly allocated underlying tensor.
    return Tensor(
        _infinicore.paged_attention(
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            block_tables._underlying,
            seq_lens._underlying,
            slopes,
            scale,
        )
    )
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def paged_caching(
    k: Tensor,
    v: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    slot_mapping: Tensor,
):
    """Scatter per-token keys/values into the block-paged KV cache pools.

    The native op works fully in place and returns nothing, so its result
    must not be wrapped in ``Tensor(...)`` (the previous code effectively
    called ``Tensor(None)``). The mutated pools are returned for convenience.
    """
    _infinicore.paged_caching_(
        k._underlying,
        v._underlying,
        k_cache._underlying,
        v_cache._underlying,
        slot_mapping._underlying,
    )
    return (k_cache, v_cache)
#include "infinicore/ops/paged_attention.hpp"
#include "../../utils.hpp"
namespace infinicore::op {

// Registry of per-device paged-attention implementations (populated by the
// backend translation units via static registration).
// Removed the stray ';' that previously followed this function body.
common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
    static common::OpDispatcher<PagedAttention::schema> dispatcher_;
    return dispatcher_;
}

// Validates device placement, then forwards to the implementation registered
// for the output tensor's device type.
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
    // alibi_slopes is optional, hence not covered by the same-device check.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
    infinicore::context::setDevice(out->device());
    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

// Out-of-place entry point: output shares q's shape/dtype/device.
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
    return out;
}

// In-place entry point.
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention.hpp"
#include <infiniop.h>

#include <functional>
namespace infinicore::op::paged_attention_impl::infiniop {

// Thread-local LRU cache of infiniop descriptors keyed by a hash of the call
// signature; evicted descriptors are destroyed by the callback.
thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedAttentionDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop-backed paged attention: builds (or reuses) a descriptor, sizes a
// workspace, and launches on the current stream.
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
    // BUGFIX: `scale` and the presence of alibi slopes are baked into the
    // descriptor at creation time, so they must be part of the cache key.
    // Previously a descriptor built with a different scale (or with/without
    // alibi) could be reused for calls with identical tensor signatures.
    auto mix = [&seed](size_t h) { seed ^= h + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); };
    mix(std::hash<float>{}(scale));
    mix(static_cast<size_t>(alibi_slopes.has_value()));
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);
    auto desc_opt = cache.get(seed);
    infiniopPagedAttentionDescriptor_t desc = nullptr;
    if (!desc_opt) {
        INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
            context::getInfiniopHandle(device), &desc,
            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
            alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
            scale));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
    INFINICORE_CHECK_ERROR(infiniopPagedAttention(
        desc, workspace->data(), workspace_size,
        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
        alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
        context::getStream()));
}

// Register this implementation for all device types at static-init time.
static bool registered = []() {
    PagedAttention::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_attention_impl::infiniop
#include "infinicore/ops/paged_caching.hpp"
#include "../../utils.hpp"
namespace infinicore::op {

// Registry of per-device paged-caching implementations (populated by backend
// translation units via static registration).
// Removed the stray ';' that previously followed this function body.
common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
    static common::OpDispatcher<PagedCaching::schema> dispatcher_;
    return dispatcher_;
}

// Validates device placement, then forwards to the implementation registered
// for k's device type.
void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
    infinicore::context::setDevice(k->device());
    dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping);
}

// In-place public entry point: scatters k/v into the cache pools.
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
}

} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_caching.hpp"
#include <infiniop.h>
namespace infinicore::op::paged_caching_impl::infiniop {

// Thread-local LRU cache of paged-caching descriptors; entries evicted from
// the cache are destroyed by the callback below.
thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedCachingDescriptor_t &desc) {
        if (desc == nullptr) {
            return;
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc));
        desc = nullptr;
    });

// infiniop-backed implementation: looks up (or builds) a descriptor for this
// tensor signature, sizes a workspace, then launches on the current stream.
void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    const size_t key = hash_combine(k, v, k_cache, v_cache, slot_mapping);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    infiniopPagedCachingDescriptor_t desc = nullptr;
    if (auto cached = cache.get(key)) {
        desc = *cached;
    } else {
        INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
            context::getInfiniopHandle(device), &desc,
            k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc()));
        cache.put(key, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopPagedCaching(
        desc, workspace->data(), workspace_size,
        k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream()));
}

// Register this implementation for all device types at static-init time.
static bool registered = []() {
    PagedCaching::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_caching_impl::infiniop
......@@ -9,6 +9,8 @@
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
......@@ -28,6 +30,8 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_paged_attention(m);
bind_paged_caching(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/paged_attention.hpp"
namespace py = pybind11;
namespace infinicore::ops {

// NOTE: these helpers are defined in a header (#pragma once), so they must be
// `inline` — otherwise every translation unit that includes this file emits
// its own external definition, violating the One Definition Rule.

// Converts a Python `None` into std::nullopt before forwarding to the
// out-of-place C++ op.
inline Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
}

// Same None -> nullopt conversion for the in-place variant writing into `out`.
inline void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
}

// Registers both paged-attention bindings on module m.
inline void bind_paged_attention(py::module &m) {
    m.def("paged_attention",
          &ops::py_paged_attention,
          py::arg("q"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("block_tables"),
          py::arg("seq_lens"),
          py::arg("alibi_slopes"),
          py::arg("scale"),
          R"doc(Paged attention of query and key cache tensors.)doc");
    m.def("paged_attention_",
          &ops::py_paged_attention_,
          py::arg("out"),
          py::arg("q"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("block_tables"),
          py::arg("seq_lens"),
          py::arg("alibi_slopes"),
          py::arg("scale"),
          R"doc(In-place paged attention of query and key cache tensors.)doc");
}

} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/paged_caching.hpp"
namespace py = pybind11;
namespace infinicore::ops {
// Registers the in-place paged_caching_ binding on module m.
// (Already `inline`, as required for a function defined in a header.)
inline void bind_paged_caching(py::module &m) {
    m.def("paged_caching_",
          &op::paged_caching_,
          py::arg("k"),
          py::arg("v"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("slot_mapping"),
          R"doc(Paged caching of key and value tensors.)doc");
}
} // namespace infinicore::ops
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
TensorInitializer,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format:
_TEST_CASES_DATA = [
    # (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
    (1, 1, 1, 128, 16, 15, False),
    # NOTE(review): larger configurations are currently disabled; only the
    # single tiny case above runs. Re-enable once the kernel is validated.
    # (4, 40, 40, 128, 16, 1024, False),
    # (6, 40, 40, 128, 16, 1024, False),
    # (3, 8, 8, 128, 16, 1024, False),
    # (8, 64, 8, 128, 16, 2048, False),
]
# Tolerance configuration
# Per-dtype absolute/relative tolerances used when comparing the infinicore
# result against the torch reference implementation.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for paged_attention operation.
    Each test case contains all necessary information for execution and validation.
    """
    test_cases = []
    for (
        num_seqs,
        num_heads,
        num_kv_heads,
        head_size,
        block_size,
        max_seq_len,
        use_alibi,
    ) in _TEST_CASES_DATA:
        scale = 1.0 / (head_size**0.5)
        max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        num_blocks = num_seqs * max_blocks_per_seq  # A reasonable number for testing
        # BUGFIX: sequence lengths must be bounded by max_seq_len. The block
        # tables built below only cover max_blocks_per_seq blocks per
        # sequence, so the previous randint(1, 1024) could draw lengths that
        # index past the end of the block table in the reference impl.
        seq_lens_torch = torch.randint(
            1, max_seq_len + 1, (num_seqs,), dtype=torch.int32
        )
        # Each sequence gets a disjoint, contiguous run of block indices.
        block_tables = torch.arange(
            0, num_seqs * max_blocks_per_seq, dtype=torch.int32
        ).view(num_seqs, max_blocks_per_seq)
        q_shape = (num_seqs, num_heads, head_size)
        k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        block_tables_shape = block_tables.shape
        seq_lens_shape = seq_lens_torch.shape
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
            # Create typed tensor specs
            q_spec = TensorSpec.from_tensor(q_shape, None, dtype)
            k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype)
            v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype)
            block_tables_spec = TensorSpec.from_tensor(
                block_tables_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=block_tables,
                dtype=infinicore.int32,
            )
            seq_lens_spec = TensorSpec.from_tensor(
                seq_lens_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=seq_lens_torch,
                dtype=infinicore.int32,
            )
            # Paged attention returns a new output tensor shaped like q
            # (output_spec=None: the framework derives it from the op result).
            test_cases.append(
                TestCase(
                    inputs=[
                        q_spec,
                        k_cache_spec,
                        v_cache_spec,
                        block_tables_spec,
                        seq_lens_spec,
                    ],
                    kwargs={"alibi_slopes": None, "scale": scale},
                    output_spec=None,
                    comparison_target=0,
                    tolerance=tolerance,
                    description="PagedAttention",
                )
            )
    return test_cases
def ref_masked_attention(query, key, value, scale, attn_mask=None):
    """Reference attention for a single sequence.

    query: (q_len, heads, dim); key/value: (kv_len, heads, dim).
    Logits and softmax are computed in float32, then the weights are cast
    back to value.dtype before the weighted sum. Returns (q_len, heads, dim).
    """
    logits = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    weights = torch.softmax(logits, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", weights, value)
def ref_single_query_cached_kv_attention(
    query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale
):
    """Reference paged attention, computed one sequence at a time.

    Args:
        query: (num_seqs, num_query_heads, head_size) — one token per sequence.
        key_cache/value_cache: (num_blocks, num_kv_heads, block_size, head_size).
        block_tables: (num_seqs, max_blocks_per_seq) integer block indices.
        seq_lens: (num_seqs,) number of cached tokens per sequence.
        alibi_slopes: optional per-query-head slopes tensor, or None.
        scale: multiplier applied to the attention logits.

    Returns a tensor shaped like `query`.
    """
    output = torch.empty_like(query)
    num_query_heads, num_kv_heads = query.shape[1], value_cache.shape[1]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size, block_size = value_cache.shape[3], value_cache.shape[2]
    num_seqs = query.shape[0]
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = seq_lens[i].item()
        block_table = block_tables[i]
        # Gather this sequence's K/V token-by-token out of the paged pools.
        # Assumes seq_len <= block_table coverage — TODO confirm against the
        # test-case generator.
        keys_lst, values_lst = [], []
        for j in range(seq_len):
            block_num = block_table[j // block_size].item()
            block_off = j % block_size
            k = key_cache[block_num, :, block_off, :]
            v = value_cache[block_num, :, block_off, :]
            keys_lst.append(k)
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        # Grouped-query attention: replicate each KV head across its group
        # of query heads.
        if num_queries_per_kv > 1:
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
        alibi_bias = None
        if alibi_slopes is not None:
            # ALiBi: per-head linear distance penalty, broadcast over
            # (heads, 1, seq_len); the newest token gets bias 0.
            pos = torch.arange(seq_len, device=query.device).int()
            alibi_bias = (pos - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        output[i] = out.view(num_query_heads, head_size)
    return output
class OpTest(BaseOperatorTest):
    """PagedAttention operator test with simplified implementation"""

    def __init__(self):
        super().__init__("PagedAttention")

    def get_test_cases(self):
        # Cases are rebuilt on every call: seq_lens are freshly randomized.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch paged_attention reference implementation"""
        return ref_single_query_cached_kv_attention(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore paged_attention implementation"""
        # Synchronize before returning so the comparison reads final values.
        out = infinicore.paged_attention(*args, **kwargs)
        infinicore.sync_stream()
        return out
def main():
    """Main entry point: run all PagedAttention cases and exit with status."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
TensorInitializer,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
_TEST_CASES_DATA = [
    (1, 128, 8, 128, 16),
    (5, 512, 40, 128, 16),
    (16, 1024, 8, 64, 32),
    (10, 1024, 40, 64, 32),
]
# Tolerance configuration
# Per-dtype absolute/relative tolerances used when comparing against the
# torch reference implementation.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
    """
    Reference implementation for paged_caching operator.

    Args:
        key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
        value (torch.Tensor): Values, shape [ntok, nkvh, dh]
        key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
        value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
        slot_mapping (torch.Tensor): Slot mapping, shape [ntok]

    Returns:
        (k_cache, v_cache): clones of the pools with all tokens written in.
    """
    block_size = key_cache_pool.shape[2]
    # Operate on clones so the caller's pools stay untouched, mirroring an
    # operator that writes its result to dedicated output tensors.
    k_out = key_cache_pool.clone()
    v_out = value_cache_pool.clone()
    for token_idx in range(key.shape[0]):
        slot = int(slot_mapping[token_idx].item())
        blk, off = divmod(slot, block_size)
        k_out[blk, :, off, :] = key[token_idx]
        v_out[blk, :, off, :] = value[token_idx]
    return k_out, v_out
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for paged_caching operation.
    Each test case contains all necessary information for execution and validation.
    """
    test_cases = []
    for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA:
        num_blocks = 4096  # A reasonably large cache pool for testing
        # Create metadata: variable context lengths for each sequence in the batch
        context_lens_torch = torch.randint(
            1, max_seq_len + 1, (num_seqs,), dtype=torch.int32
        )
        ntok = torch.sum(context_lens_torch).item()
        # Simulate the scheduler: give each sequence a contiguous run of slots.
        slot_mapping_list = []
        current_slot = 0
        for length in context_lens_torch:
            slot_mapping_list.extend(
                range(current_slot, current_slot + length.item())
            )
            current_slot += length.item()
        # Ensure we don't exceed the total number of slots in the cache
        assert current_slot <= num_blocks * block_size, (
            "Not enough blocks in the cache pool for this test case"
        )
        slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int32)
        slot_mapping_shape = slot_mapping.shape
        k_shape = (ntok, num_kv_heads, head_size)
        v_shape = (ntok, num_kv_heads, head_size)
        k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
            # Create typed tensor specs
            k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
            v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
            k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype)
            v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype)
            slot_mapping_spec = TensorSpec.from_tensor(
                slot_mapping_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=slot_mapping,
                dtype=infinicore.int32,
            )
            # In-place operation: the op writes into k_cache and v_cache;
            # per the original note, comparison_target=0 means only k_cache
            # is compared against the reference.
            test_cases.append(
                TestCase(
                    inputs=[
                        k_spec,
                        v_spec,
                        k_cache_spec,
                        v_cache_spec,
                        slot_mapping_spec,
                    ],
                    kwargs=None,
                    output_spec=None,
                    comparison_target=0,
                    tolerance=tolerance,
                    description="PagedCaching",
                )
            )
    return test_cases
class OpTest(BaseOperatorTest):
    """PagedCaching operator test with simplified implementation"""

    def __init__(self):
        super().__init__("PagedCaching")

    def get_test_cases(self):
        # Cases are rebuilt on every call: context lengths are re-randomized.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch paged_caching implementation"""
        return ref_paged_caching(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore paged_caching implementation"""
        return infinicore.paged_caching(*args, **kwargs)
def main():
    """Main entry point: run all PagedCaching cases and exit with status."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
......@@ -148,10 +148,8 @@ def test(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
seq_lens_direct = 1023
seq_lens_torch = torch.randint(
1, seq_lens_direct + 1, (num_seqs,), dtype=torch.int64
)
seq_lens_torch = torch.randint(1, 1024, (num_seqs,), dtype=torch.int64)
seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device)
block_tables_py = torch.arange(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment