Unverified Commit 7d3195ea authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Fix IMA in DSA + MTP (#40772)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 512f5221
......@@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
if (threadIdx.x == 0) {
batch_idx[threadIdx.y] = -1;
}
__syncthreads();
for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
......@@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
}
}
#ifndef USE_ROCM
__syncwarp();
#endif
__syncthreads();
if (head_idx >= head_dim || token_idx >= num_tokens) {
// num_tokens may be an allocation upper bound when Python avoids a D2H sync.
// Only tokens covered by the exact device-side cu_seq_lens are valid to
// gather.
const int batch = batch_idx[threadIdx.y];
if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch];
const int block_idx =
block_table[batch * num_blocks + inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment