Commit e13ad8f9 authored by PanZezhong
Browse files

issue/847 correct cache_lens naming

parent 31c0af3f
......@@ -9,10 +9,10 @@ namespace infinicore::op {
class PagedAttention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
......@@ -7,7 +7,7 @@ def paged_attention(
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
seq_lens: Tensor,
cache_lens: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
......@@ -20,7 +20,7 @@ def paged_attention(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
......@@ -32,7 +32,7 @@ def paged_attention(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
......
......@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
return dispatcher_;
};
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
}
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
return out;
}
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
}
} // namespace infinicore::op
......@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
}
});
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
......@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
......
......@@ -8,21 +8,21 @@ namespace py = pybind11;
namespace infinicore::ops {
Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
return op::paged_attention(q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
}
void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
op::paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
}
inline void bind_paged_attention(py::module &m) {
......@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("cache_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(Paged attention of query and key cache tensors.)doc");
......@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("cache_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(In-place paged attention of query and key cache tensors.)doc");
......
......@@ -62,7 +62,7 @@ def parse_test_cases():
max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing
seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
cache_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
block_tables = torch.arange(
0, num_seqs * max_blocks_per_seq, dtype=torch.int64
......@@ -75,7 +75,7 @@ def parse_test_cases():
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
block_tables_shape = block_tables.shape
seq_lens_shape = seq_lens_torch.shape
cache_lens_shape = cache_lens_torch.shape
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
......@@ -91,10 +91,10 @@ def parse_test_cases():
set_tensor=block_tables,
dtype=infinicore.int64,
)
seq_lens_spec = TensorSpec.from_tensor(
seq_lens_shape,
cache_lens_spec = TensorSpec.from_tensor(
cache_lens_shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=seq_lens_torch,
set_tensor=cache_lens_torch,
dtype=infinicore.int64,
)
......@@ -108,7 +108,7 @@ def parse_test_cases():
k_cache_spec,
v_cache_spec,
block_tables_spec,
seq_lens_spec,
cache_lens_spec,
],
kwargs={"alibi_slopes": None, "scale": scale},
output_spec=None,
......@@ -132,7 +132,7 @@ def ref_masked_attention(query, key, value, scale, attn_mask=None):
def ref_single_query_cached_kv_attention(
query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale
query, key_cache, value_cache, block_tables, cache_lens, alibi_slopes, scale
):
# Reference implementation for paged attention, iterating through each sequence.
output = torch.empty_like(query)
......@@ -143,7 +143,7 @@ def ref_single_query_cached_kv_attention(
for i in range(num_seqs):
q = query[i].unsqueeze(0)
seq_len = seq_lens[i].item()
seq_len = cache_lens[i].item()
block_table = block_tables[i]
keys_lst, values_lst = [], []
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment