Commit 0f90515c authored by wooway777

issue/1065 - fix mha kv cache interface

parent 456ee3e1
@@ -36,7 +36,7 @@ void run(void *planned_meta) {
     c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
-    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
+    auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
     auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
     auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
@@ -46,7 +46,6 @@ void run(void *planned_meta) {
                            ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
                            : std::nullopt;
-    // No new KV tokens to append (pure decode, KV already written to cache).
     std::optional<const at::Tensor> k_new = std::nullopt;
     std::optional<const at::Tensor> v_new = std::nullopt;
     std::optional<const at::Tensor> rotary_cos = std::nullopt;
@@ -54,7 +53,14 @@ void run(void *planned_meta) {
     std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
     std::optional<const at::Tensor> leftpad_k = std::nullopt;
-    flash::mha_fwd_kvcache(
+    const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4
+                                 && q.size(1) == 1 && q.size(2) > k_cache.size(2)
+                                 && q.size(3) % 8 == 0 && !alibi_slopes.has_value();
+    auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
+                               : std::optional<at::Tensor>(out_tensor);
+    auto result = flash::mha_fwd_kvcache(
         q,
         k_cache,
         v_cache,
@@ -69,13 +75,16 @@ void run(void *planned_meta) {
         alibi_slopes,
         out,
         p->scale,
-        true,  // is_causal
-        -1,    // window_size_left (-1 = no sliding window)
-        -1,    // window_size_right
-        0.0f,  // softcap
-        false, // is_rotary_interleaved
-        0      // num_splits (0 = auto)
-    );
+        true,
+        -1,
+        -1,
+        0.0f,
+        false,
+        0);
+    if (use_dynamic_out) {
+        out_tensor.copy_(result[0]);
+    }
 #else
     throw std::runtime_error("FlashAttention is not enabled in this build");
 #endif
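For context, the new control flow can be read as the stand-alone sketch below. It assumes the FlashAttention convention that mha_fwd_kvcache returns a vector of tensors with the attention output at index 0, and that passing std::nullopt for the out argument makes the kernel allocate its own output buffer; fake_mha_fwd_kvcache and run_attention are hypothetical stand-ins, not functions from this repository.

#include <ATen/ATen.h>
#include <optional>
#include <vector>

// Stand-in for flash::mha_fwd_kvcache; only the optional-output behavior is
// modeled here, not the attention math.
std::vector<at::Tensor> fake_mha_fwd_kvcache(const at::Tensor &q,
                                             std::optional<at::Tensor> out) {
    // With no caller-provided buffer, allocate one. A real kernel may prefer
    // its own layout here, which is the point of the dynamic-out path.
    at::Tensor o = out.has_value() ? *out : at::empty_like(q);
    // ... attention computation would write into `o` ...
    return {o};
}

void run_attention(const at::Tensor &q, const at::Tensor &k_cache,
                   at::Tensor &out_tensor, bool has_alibi) {
    // Mirrors the commit's condition. Reading q as [batch, seqlen_q, heads,
    // head_dim] and k_cache as [batch, seqlen_k, kv_heads, head_dim], this
    // appears to select the single-token decode case (seqlen_q == 1) where q
    // has more heads than the KV cache (GQA/MQA), the head dim is a multiple
    // of 8, and no ALiBi slopes are used.
    const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4
                                 && q.size(1) == 1 && q.size(2) > k_cache.size(2)
                                 && q.size(3) % 8 == 0 && !has_alibi;
    auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
                               : std::optional<at::Tensor>(out_tensor);
    auto result = fake_mha_fwd_kvcache(q, out);
    if (use_dynamic_out) {
        // The kernel wrote into its own buffer; copy back into the caller's
        // pre-allocated output so the external interface is unchanged.
        out_tensor.copy_(result[0]);
    }
}

The trailing literal arguments in the new call keep the meanings documented on the old side of the diff: is_causal = true, window_size_left/right = -1 (no sliding window), softcap = 0.0f, is_rotary_interleaved = false, and num_splits = 0 (auto).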