Unverified Commit 582eb8fc authored by Jongseok Park's avatar Jongseok Park Committed by GitHub
Browse files

Fix params.seqlen_k reference in the splitkv kernel to binfo.actual_seqlen_k (#18)

parent f9d2c100
...@@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return; if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
const int n_block_min = !Is_local const int n_block_min = !Is_local
? n_split_idx * n_blocks_per_split ? n_split_idx * n_blocks_per_split
: std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment