Unverified Commit 130d5fd8 authored by Woosuk Kwon, committed by GitHub

Fix a bug in attention kernel (#68)

parent e070829a
...
@@ -345,7 +345,7 @@ void single_query_cached_kv_attention_launcher(
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(T);
+  int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   int shared_mem_size = std::max(logits_size, outputs_size);
...
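
For context, a minimal sketch of why this one-character class of fix matters (the function and parameter names below mirror the diff; the standalone wrapper itself is illustrative, not code from the repository): the kernel stores the attention logits in shared memory as float regardless of the template scalar type T, so the dynamic shared-memory reservation must be sized with sizeof(float). Sizing it with sizeof(T) under-allocates whenever T is a 16-bit type such as half, which can corrupt the logits buffer for long contexts.

  #include <algorithm>

  // Sketch (assumed wrapper): compute the dynamic shared-memory size as the
  // larger of the float logits buffer and the float per-warp outputs buffer.
  template <typename T, int NUM_THREADS, int WARP_SIZE, int BLOCK_SIZE>
  int compute_shared_mem_size(int max_context_len, int head_size) {
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    // Round the context length up to a multiple of the KV block size.
    int padded_max_context_len =
        ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
    // Logits live in shared memory as float even when T is half, so size
    // with sizeof(float), not sizeof(T) -- the bug this commit fixes.
    int logits_size = padded_max_context_len * sizeof(float);
    // Per-warp partial outputs are also accumulated in float.
    int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
    // The two buffers are used in different phases, so only the max is needed.
    return std::max(logits_size, outputs_size);
  }

With T = half, sizeof(T) is 2 while sizeof(float) is 4, so the old code reserved only half the shared memory the logits actually occupy.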