Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5d06dc5
Unverified
Commit
a5d06dc5
authored
Mar 11, 2026
by
Julien Denize
Committed by
GitHub
Mar 11, 2026
Browse files
Add 320 dimension size support to MLA (#36161)
Signed-off-by:
Julien Denize
<
julien.denize@mistral.ai
>
parent
5efa206a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
9 deletions
+25
-9
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+19
-6
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+5
-2
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+1
-1
No files found.
csrc/cache_kernels.cu
View file @
a5d06dc5
...
...
@@ -919,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE
)
\
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE,
576,
\
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE
, ENTRY_SZ)
\
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE,
ENTRY_SZ,
\
thread_block_size> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
...
...
@@ -931,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
...
...
@@ -960,9 +966,10 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK
(
seq_starts
.
value
().
dtype
()
==
torch
::
kInt32
,
"seq_starts must be int32"
);
}
TORCH_CHECK
(
head_dim
==
576
,
"gather_and_maybe_dequant_cache only support the head_dim to 576 "
"for better performance"
)
TORCH_CHECK
(
head_dim
==
320
||
head_dim
==
576
,
"gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
"for better performance"
)
TORCH_CHECK
(
src_cache
.
device
()
==
dst
.
device
(),
"src_cache and dst must be on the same device"
);
...
...
@@ -987,7 +994,13 @@ void gather_and_maybe_dequant_cache(
const
int32_t
*
seq_starts_ptr
=
seq_starts
.
has_value
()
?
seq_starts
.
value
().
data_ptr
<
int32_t
>
()
:
nullptr
;
DISPATCH_BY_KV_CACHE_DTYPE
(
dst
.
dtype
(),
kv_cache_dtype
,
CALL_GATHER_CACHE
);
if
(
head_dim
==
576
)
{
DISPATCH_BY_KV_CACHE_DTYPE
(
dst
.
dtype
(),
kv_cache_dtype
,
CALL_GATHER_CACHE_576
);
}
else
{
DISPATCH_BY_KV_CACHE_DTYPE
(
dst
.
dtype
(),
kv_cache_dtype
,
CALL_GATHER_CACHE_320
);
}
}
namespace
vllm
{
...
...
tests/kernels/attention/test_cache.py
View file @
a5d06dc5
...
...
@@ -23,7 +23,7 @@ CACHE_LAYOUTS = ["NHD", "HND"]
KV_SCALE_TYPES
=
[
"tensor"
,
"attn_head"
]
# Parameters for MLA tests.
KV_LORA_RANKS
=
[
512
]
KV_LORA_RANKS
=
[
256
,
512
]
QK_ROPE_HEAD_DIMS
=
[
64
]
NUM_TOKENS_MLA
=
[
42
]
BLOCK_SIZES_MLA
=
[
16
]
...
...
@@ -627,6 +627,8 @@ def test_concat_and_cache_ds_mla(
pytest
.
skip
(
"concat_and_cache_mla doesn't support fp8_ds_mla on ROCm"
)
if
dtype
.
itemsize
!=
2
:
pytest
.
skip
(
"ds_mla only supports 16-bit input"
)
if
kv_lora_rank
!=
512
:
pytest
.
skip
(
"fp8_ds_mla requires kv_lora_rank == 512"
)
kv_cache_dtype
=
"fp8_ds_mla"
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -663,7 +665,8 @@ def test_concat_and_cache_ds_mla(
ref_cache_32bit
=
ref_cache_slice
.
view
(
torch
.
float32
)
kv_c_data
=
kv_c
[
i
]
for
tile_idx
in
range
(
4
):
num_tiles
=
kv_lora_rank
//
128
for
tile_idx
in
range
(
num_tiles
):
tile_start
=
tile_idx
*
128
tile_end
=
(
tile_idx
+
1
)
*
128
tile_data
[:]
=
kv_c_data
[
tile_start
:
tile_end
]
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
a5d06dc5
...
...
@@ -1148,7 +1148,7 @@ class MLACommonBackend(AttentionBackend):
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
return
[
320
,
576
]
@
classmethod
def
is_mla
(
cls
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment