"vscode:/vscode.git/clone" did not exist on "c53711bd63dba1945c45f654b9f8e0776b02f7f2"
Unverified Commit 7d3195ea authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Fix IMA in DSA + MTP (#40772)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 512f5221
...@@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( ...@@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block // Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE]; __shared__ int batch_idx[BLOCK_Y_SIZE];
if (threadIdx.x == 0) {
batch_idx[threadIdx.y] = -1;
}
__syncthreads();
for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
iter++) { iter++) {
int tid = iter * blockDim.x + threadIdx.x; int tid = iter * blockDim.x + threadIdx.x;
...@@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( ...@@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
} }
} }
#ifndef USE_ROCM __syncthreads();
__syncwarp();
#endif
if (head_idx >= head_dim || token_idx >= num_tokens) { // num_tokens may be an allocation upper bound when Python avoids a D2H sync.
// Only tokens covered by the exact device-side cu_seq_lens are valid to
// gather.
const int batch = batch_idx[threadIdx.y];
if (head_idx >= head_dim || token_idx >= num_tokens || batch < 0) {
return; return;
} }
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; const int inbatch_seq_idx = token_idx - cu_seq_lens[batch];
const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + const int block_idx =
inbatch_seq_idx / cache_block_size]; block_table[batch * num_blocks + inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride; const int64_t src_block_offset = block_idx * block_stride;
const int64_t cache_inblock_offset = const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx; (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment