Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
33368140
Unverified
Commit
33368140
authored
Mar 08, 2025
by
Tyler Michael Smith
Committed by
GitHub
Mar 07, 2025
Browse files
[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)
Signed-off-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
ef640440
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+8
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+4
-3
No files found.
vllm/v1/kv_cache_interface.py
View file @
33368140
...
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
num_kv_heads
:
int
num_kv_heads
:
int
head_size
:
int
head_size
:
int
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
use_mla
:
bool
@
property
@
property
def
type_id
(
self
)
->
str
:
def
type_id
(
self
)
->
str
:
...
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
@
property
@
property
def
page_size_bytes
(
self
)
->
int
:
def
page_size_bytes
(
self
)
->
int
:
return
2
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
# For MLA we only store a single latent vector
coef
=
1
if
self
.
use_mla
else
2
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
*
get_dtype_size
(
self
.
dtype
)
*
get_dtype_size
(
self
.
dtype
)
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
33368140
...
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
KVCacheSpec
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
if
isinstance
(
attn_module
,
FusedMoE
):
continue
continue
# TODO: Support other attention modules, e.g., sliding window,
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
, MLA.
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
...
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
attn_module
.
dtype
,
)
use_mla
=
use_mla
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
AttentionType
.
ENCODER_ONLY
):
# encoder-only attention does not need KV cache.
# encoder-only attention does not need KV cache.
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
33368140
...
@@ -323,6 +323,7 @@ class TPUModelRunner:
...
@@ -323,6 +323,7 @@ class TPUModelRunner:
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
attn_module
.
dtype
,
use_mla
=
False
,
)
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
AttentionType
.
ENCODER_ONLY
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment