Commit 0f90515c authored by wooway777's avatar wooway777
Browse files

issue/1065 - fix mha kv cache interface

parent 456ee3e1
......@@ -36,7 +36,7 @@ void run(void *planned_meta) {
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
......@@ -46,7 +46,6 @@ void run(void *planned_meta) {
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
: std::nullopt;
// No new KV tokens to append (pure decode, KV already written to cache).
std::optional<const at::Tensor> k_new = std::nullopt;
std::optional<const at::Tensor> v_new = std::nullopt;
std::optional<const at::Tensor> rotary_cos = std::nullopt;
......@@ -54,7 +53,14 @@ void run(void *planned_meta) {
std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
flash::mha_fwd_kvcache(
const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4
&& q.size(1) == 1 && q.size(2) > k_cache.size(2)
&& q.size(3) % 8 == 0 && !alibi_slopes.has_value();
auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
: std::optional<at::Tensor>(out_tensor);
auto result = flash::mha_fwd_kvcache(
q,
k_cache,
v_cache,
......@@ -69,13 +75,16 @@ void run(void *planned_meta) {
alibi_slopes,
out,
p->scale,
true, // is_causal
-1, // window_size_left (-1 = no sliding window)
-1, // window_size_right
0.0f, // softcap
false, // is_rotary_interleaved
0 // num_splits (0 = auto)
);
true,
-1,
-1,
0.0f,
false,
0);
if (use_dynamic_out) {
out_tensor.copy_(result[0]);
}
#else
throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment