Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
33368140
Unverified
Commit
33368140
authored
Mar 08, 2025
by
Tyler Michael Smith
Committed by
GitHub
Mar 07, 2025
Browse files
[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)
Signed-off-by:
Tyler Michael Smith
<
tyler@neuralmagic.com
>
parent
ef640440
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+8
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+4
-3
No files found.
vllm/v1/kv_cache_interface.py
View file @
33368140
...
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
...
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
def
type_id
(
self
)
->
str
:
def
type_id
(
self
)
->
str
:
"""
"""
The type identifier of this KV cache.
The type identifier of this KV cache.
Return different strings for layers with different KV cache type (e.g.,
Return different strings for layers with different KV cache type (e.g.,
different number of tokens like full attention vs sliding window
different number of tokens like full attention vs sliding window
attention, different KV cache size per token like layers with different
attention, different KV cache size per token like layers with different
number of heads)
number of heads)
Returns:
Returns:
...
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
num_kv_heads
:
int
num_kv_heads
:
int
head_size
:
int
head_size
:
int
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
use_mla
:
bool
@
property
@
property
def
type_id
(
self
)
->
str
:
def
type_id
(
self
)
->
str
:
...
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
@
property
@
property
def
page_size_bytes
(
self
)
->
int
:
def
page_size_bytes
(
self
)
->
int
:
return
2
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
# For MLA we only store a single latent vector
coef
=
1
if
self
.
use_mla
else
2
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
*
get_dtype_size
(
self
.
dtype
)
*
get_dtype_size
(
self
.
dtype
)
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
...
@@ -104,7 +107,7 @@ class KVCacheConfig:
...
@@ -104,7 +107,7 @@ class KVCacheConfig:
2. (not implemented yet) A model with the same number of full attention
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
"""
groups
:
list
[
list
[
str
]]
groups
:
list
[
list
[
str
]]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
33368140
...
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1460,13 +1460,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
KVCacheSpec
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
if
isinstance
(
attn_module
,
FusedMoE
):
continue
continue
# TODO: Support other attention modules, e.g., sliding window,
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
, MLA.
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
...
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1474,7 +1475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
attn_module
.
dtype
,
)
use_mla
=
use_mla
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
AttentionType
.
ENCODER_ONLY
):
# encoder-only attention does not need KV cache.
# encoder-only attention does not need KV cache.
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
33368140
...
@@ -303,10 +303,10 @@ class TPUModelRunner:
...
@@ -303,10 +303,10 @@ class TPUModelRunner:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Attention module in the static forward context.
Returns:
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
format. Layers that do not need KV cache are not included.
"""
"""
...
@@ -323,6 +323,7 @@ class TPUModelRunner:
...
@@ -323,6 +323,7 @@ class TPUModelRunner:
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
attn_module
.
dtype
,
use_mla
=
False
,
)
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
AttentionType
.
ENCODER_ONLY
):
...
@@ -764,7 +765,7 @@ class TPUModelRunner:
...
@@ -764,7 +765,7 @@ class TPUModelRunner:
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initialize KV cache based on `kv_cache_config`.
Args:
Args:
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
groups
)
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment