Commit 665f383b authored by suss's avatar suss Committed by wooway777
Browse files

issue/1065 - add mha_kvcache

parent 21c6af2d
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
// Flash Attention KV-cache decode op.
//
// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
// paged KV cache.
//
// Tensor shapes:
// out : [batch_size, seqlen_q, num_heads, head_size]
// q : [batch_size, seqlen_q, num_heads, head_size]
// k_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// v_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// seqlens_k : [batch_size] int32 — total KV length per request
// block_table : [batch_size, max_num_blocks_per_seq] int32
//
// Declares the graph-op class (ctor + static execute + per-device dispatch
// table) via the project macro. Parameter order below is the op signature.
INFINICORE_GRAPH_OP_CLASS(
MhaKVCache,
Tensor, // out
const Tensor &, // q
const Tensor &, // k_cache
const Tensor &, // v_cache
const Tensor &, // seqlens_k
const Tensor &, // block_table
std::optional<Tensor>, // alibi_slopes
float); // scale
// Functional form: allocates an output tensor shaped/typed like `q` and
// returns the attention result.
Tensor mha_kvcache(const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale);
// In-place form: writes the attention result into caller-provided `out`.
void mha_kvcache_(Tensor out,
const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale);
} // namespace infinicore::op
#include "infinicore/ops/mha_kvcache.hpp"

#include <utility>

#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MhaKVCache);
// Graph-op constructor: verifies that all required tensors live on the same
// device, then dispatches to the backend registered for that device type.
// NOTE(review): alibi_slopes is excluded from the same-device assert —
// presumably validated by the backend when present; confirm.
MhaKVCache::MhaKVCache(Tensor out,
const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, seqlens_k, block_table);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}
// Static entry point used by the public wrappers: per the macro, either
// records the op into the graph being captured or runs it eagerly.
void MhaKVCache::execute(Tensor out,
const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
MhaKVCache,
out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}
// In-place MHA KV-cache decode: writes the attention result into `out`.
//
// Shapes are documented in the header; `out` must match `q`:
// [batch_size, seqlen_q, num_heads, head_size].
//
// `alibi_slopes` arrived by value, so it is moved into execute() rather
// than copied (avoids an extra optional/handle copy on the hot path).
void mha_kvcache_(Tensor out,
                  const Tensor &q,
                  const Tensor &k_cache,
                  const Tensor &v_cache,
                  const Tensor &seqlens_k,
                  const Tensor &block_table,
                  std::optional<Tensor> alibi_slopes,
                  float scale) {
    MhaKVCache::execute(out, q, k_cache, v_cache, seqlens_k, block_table,
                        std::move(alibi_slopes), scale);
}
// Functional MHA KV-cache decode: allocates the output and returns it.
//
// Returns a tensor with q's shape/dtype/device:
// [batch_size, seqlen_q, num_heads, head_size].
Tensor mha_kvcache(const Tensor &q,
                   const Tensor &k_cache,
                   const Tensor &v_cache,
                   const Tensor &seqlens_k,
                   const Tensor &block_table,
                   std::optional<Tensor> alibi_slopes,
                   float scale) {
    // Output shape matches q: [batch_size, seqlen_q, num_heads, head_size]
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    // Move the by-value optional through instead of copying it.
    mha_kvcache_(out, q, k_cache, v_cache, seqlens_k, block_table,
                 std::move(alibi_slopes), scale);
    return out;
}
} // namespace infinicore::op
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
namespace infinicore::op::mha_kvcache_impl::flashattn {
// Operands captured at plan time and replayed by run(). GraphTensor wraps
// each operand so the graph runtime can resolve it at execution time.
struct PlannedMeta {
graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table;
std::optional<graph::GraphTensor> alibi_slopes; // optional ALiBi bias slopes
float scale; // softmax scale forwarded to flash::mha_fwd_kvcache
};
// Captures graph handles for every operand into a heap-allocated
// PlannedMeta. Ownership of the returned pointer passes to the graph
// runtime, which releases it via cleanup().
// NOTE(review): GraphTensor wrapping happens in braced-init order
// (out, q, caches, ..., alibi) — if capture registers with the graph,
// that order may be significant; left untouched.
void *plan(Tensor out,
const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale) {
return new PlannedMeta{
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k_cache),
graph::GraphTensor(v_cache),
graph::GraphTensor(seqlens_k),
graph::GraphTensor(block_table),
alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
scale};
}
// Executes the planned flash-attention decode over the paged KV cache.
//
// `planned_meta` must be a PlannedMeta produced by plan(). Converts every
// captured operand to an ATen tensor and calls flash::mha_fwd_kvcache with
// causal masking, no sliding window, no softcap, and auto split selection.
void run(void *planned_meta) {
    // Make the flash kernel launch on InfiniCore's CUDA stream rather than
    // the ambient default stream.
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    // static_cast is the idiomatic (and sufficient) cast back from void*.
    auto *p = static_cast<PlannedMeta *>(planned_meta);
    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
    auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
    auto alibi_slopes = p->alibi_slopes
                            ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
                            : std::nullopt;
    // No new KV tokens to append (pure decode, KV already written to cache).
    std::optional<const at::Tensor> k_new = std::nullopt;
    std::optional<const at::Tensor> v_new = std::nullopt;
    std::optional<const at::Tensor> rotary_cos = std::nullopt;
    std::optional<const at::Tensor> rotary_sin = std::nullopt;
    std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
    std::optional<const at::Tensor> leftpad_k = std::nullopt;
    flash::mha_fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k_new,
        v_new,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,
        p->scale,
        true,  // is_causal
        -1,    // window_size_left (-1 = no sliding window)
        -1,    // window_size_right
        0.0f,  // softcap
        false, // is_rotary_interleaved
        0      // num_splits (0 = auto)
    );
}
// Releases the PlannedMeta allocated by plan() and nulls the slot so a
// second cleanup() is harmless (delete on nullptr is a no-op).
void cleanup(void **planned_meta_ptr) {
    // static_cast from void* is sufficient; reinterpret_cast is not needed.
    delete static_cast<PlannedMeta *>(*planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}
// Register this plan/run/cleanup triple as the MhaKVCache implementation
// for all device types.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MhaKVCache, &plan, &run, &cleanup);
} // namespace infinicore::op::mha_kvcache_impl::flashattn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment