Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d3195ea
Unverified
Commit
7d3195ea
authored
Apr 24, 2026
by
Woosuk Kwon
Committed by
GitHub
Apr 24, 2026
Browse files
[Bugfix] Fix IMA in DSA + MTP (#40772)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
512f5221
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
7 deletions
+14
-7
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+14
-7
No files found.
csrc/cache_kernels.cu
View file @
7d3195ea
...
@@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
...
@@ -599,6 +599,11 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
const
int
head_idx
=
(
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
)
*
VEC_SIZE
;
const
int
head_idx
=
(
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
)
*
VEC_SIZE
;
// Find batch index within a block
// Find batch index within a block
__shared__
int
batch_idx
[
BLOCK_Y_SIZE
];
__shared__
int
batch_idx
[
BLOCK_Y_SIZE
];
if
(
threadIdx
.
x
==
0
)
{
batch_idx
[
threadIdx
.
y
]
=
-
1
;
}
__syncthreads
();
for
(
int
iter
=
0
;
iter
<
cuda_utils
::
ceil_div
(
batch_size
,
int
(
blockDim
.
x
));
for
(
int
iter
=
0
;
iter
<
cuda_utils
::
ceil_div
(
batch_size
,
int
(
blockDim
.
x
));
iter
++
)
{
iter
++
)
{
int
tid
=
iter
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
iter
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
...
@@ -611,16 +616,18 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
}
}
}
}
#ifndef USE_ROCM
__syncthreads
();
__syncwarp
();
#endif
if
(
head_idx
>=
head_dim
||
token_idx
>=
num_tokens
)
{
// num_tokens may be an allocation upper bound when Python avoids a D2H sync.
// Only tokens covered by the exact device-side cu_seq_lens are valid to
// gather.
const
int
batch
=
batch_idx
[
threadIdx
.
y
];
if
(
head_idx
>=
head_dim
||
token_idx
>=
num_tokens
||
batch
<
0
)
{
return
;
return
;
}
}
const
int
inbatch_seq_idx
=
token_idx
-
cu_seq_lens
[
batch
_idx
[
threadIdx
.
y
]
];
const
int
inbatch_seq_idx
=
token_idx
-
cu_seq_lens
[
batch
];
const
int
block_idx
=
block_table
[
batch_idx
[
threadIdx
.
y
]
*
num_blocks
+
const
int
block_idx
=
inbatch_seq_idx
/
cache_block_size
];
block_table
[
batch
*
num_blocks
+
inbatch_seq_idx
/
cache_block_size
];
const
int64_t
src_block_offset
=
block_idx
*
block_stride
;
const
int64_t
src_block_offset
=
block_idx
*
block_stride
;
const
int64_t
cache_inblock_offset
=
const
int64_t
cache_inblock_offset
=
(
inbatch_seq_idx
%
cache_block_size
)
*
head_dim
+
head_idx
;
(
inbatch_seq_idx
%
cache_block_size
)
*
head_dim
+
head_idx
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment