Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ac98edc
Unverified
Commit
3ac98edc
authored
Apr 16, 2025
by
billishyahao
Committed by
GitHub
Apr 15, 2025
Browse files
[Feature] add model aware kv ops helper (#16020)
Signed-off-by:
billishyahao
<
bill.he@amd.com
>
parent
966c742e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
101 deletions
+123
-101
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
...uted/kv_transfer/kv_connector/mooncake_store_connector.py
+12
-27
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
.../distributed/kv_transfer/kv_connector/simple_connector.py
+21
-74
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+90
-0
No files found.
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
View file @
3ac98edc
# SPDX-License-Identifier: Apache-2.0
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
...
...
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
config
:
VllmConfig
,
):
self
.
config
=
config
.
kv_transfer_config
self
.
tp_size
=
config
.
parallel_config
.
tensor_parallel_size
self
.
kv_helper
=
kv_helper
(
config
)
self
.
local_tp_rank
=
local_rank
# Init kv_store
...
...
@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
slot_mapping_flat
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
model_config
=
model_executable
.
model
.
config
num_heads
=
int
(
model_config
.
num_key_value_heads
/
self
.
tp_size
)
hidden_size
=
model_config
.
hidden_size
num_attention_heads
=
model_config
.
num_attention_heads
head_size
=
int
(
hidden_size
/
num_attention_heads
)
num_heads
,
head_size
=
self
.
kv_helper
.
get_model_args
(
model_executable
)
for
idx
,
slen
in
enumerate
(
seq_lens
):
start_pos
=
sum
(
seq_lens
[:
idx
])
...
...
@@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase):
for
layer_id
in
range
(
start_layer
,
end_layer
):
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
=
kv_cache
[
0
].
reshape
(
-
1
,
num_heads
,
head_size
)
value_cache
=
kv_cache
[
1
].
reshape
(
-
1
,
num_heads
,
head_size
)
key_cache
,
value_cache
=
self
.
kv_helper
.
get_kv_from_cache
(
kv_cache
,
num_heads
,
head_size
)
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
keys
.
append
(
key_cache
[
current_slot_mapping
].
unsqueeze
(
0
))
...
...
@@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase):
layer
=
model_executable
.
model
.
layers
[
layer_id
]
# get kvcache object
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
,
value_cache
=
kv_cache
[
0
],
kv_cache
[
1
]
# get remote kvcache
# get remote kvcache
remote_k
,
remote_v
=
remote_kv
[
0
][
layer_id
],
remote_kv
[
1
][
layer_id
]
# use ops.reshape_and_cache_flash to put kv into kvcache
ops
.
reshape_and_cache_flash
(
remote_k
.
to
(
key_cache
.
device
),
remote_v
.
to
(
value_cache
.
device
),
key_cache
,
value_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
self
.
kv_helper
.
put_kv_to_cache
(
model_executable
,
remote_k
,
remote_v
,
layer
,
kv_cache
,
slot_mapping
,
start_pos
,
end_pos
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
...
...
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
View file @
3ac98edc
...
...
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer
import
(
SimpleBuffer
)
from
vllm.logger
import
init_logger
...
...
@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
):
self
.
config
=
config
.
kv_transfer_config
self
.
tp_size
=
config
.
parallel_config
.
tensor_parallel_size
self
.
is_deepseek_mla
=
config
.
model_config
.
is_deepseek_mla
self
.
use_mla_opt
=
not
envs
.
VLLM_MLA_DISABLE
self
.
kv_helper
=
kv_helper
(
config
)
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
from
vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe
import
(
...
...
@@ -165,31 +163,7 @@ class SimpleConnector(KVConnectorBase):
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
model_config
=
model_executable
.
model
.
config
num_heads
=
int
(
model_config
.
num_key_value_heads
/
self
.
tp_size
)
hidden_size
=
model_config
.
hidden_size
num_attention_heads
=
model_config
.
num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
head_size
=
model_config
.
kv_lora_rank
+
\
model_config
.
qk_rope_head_dim
num_heads
=
1
elif
self
.
is_deepseek_mla
and
not
self
.
use_mla_opt
:
head_size
=
model_config
.
qk_nope_head_dim
+
\
model_config
.
qk_rope_head_dim
else
:
head_size
=
getattr
(
model_config
,
"head_dim"
,
int
(
hidden_size
//
num_attention_heads
))
num_heads
,
head_size
=
self
.
kv_helper
.
get_model_args
(
model_executable
)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
...
...
@@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase):
for
layer_id
in
range
(
start_layer
,
end_layer
):
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
key_cache
=
kv_cache
.
reshape
(
-
1
,
num_heads
,
head_size
)
value_cache
=
kv_cache
.
reshape
(
-
1
,
num_heads
,
head_size
)
else
:
key_cache
=
kv_cache
[
0
].
reshape
(
-
1
,
num_heads
,
head_size
)
value_cache
=
kv_cache
[
1
].
reshape
(
-
1
,
num_heads
,
head_size
)
key_cache
,
value_cache
=
self
.
kv_helper
.
get_kv_from_cache
(
kv_cache
,
num_heads
,
head_size
)
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
...
...
@@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase):
# and hidden states.
bypass_model_exec
=
True
model_config
=
model_executable
.
model
.
config
input_tokens_tensor
=
model_input
.
input_tokens
seq_lens
=
model_input
.
attn_metadata
.
seq_lens
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
hidden_or_intermediate_states_for_one_req
=
[]
...
...
@@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase):
end_pos
=
start_pos
+
num_computed_tokens
# put received KV caches into paged memory
for
i
in
range
(
model_executable
.
model
.
start_layer
,
model_executable
.
model
.
end_layer
):
kv_cache
=
kv_caches
[
i
-
model_executable
.
model
.
start_layer
]
layer
=
model_executable
.
model
.
layers
[
i
]
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
layer
.
self_attn
.
attn
=
layer
.
self_attn
.
mla_attn
k_c_normed_k_pe
=
keys
[
i
-
model_executable
.
model
.
start_layer
].
to
(
kv_cache
.
device
).
squeeze
(
1
)
k_c_normed
=
k_c_normed_k_pe
[:,
:
model_config
.
kv_lora_rank
]
k_pe
=
k_c_normed_k_pe
[:,
model_config
.
kv_lora_rank
:]
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
,
kv_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
)
else
:
key_cache
,
value_cache
=
kv_cache
[
0
],
kv_cache
[
1
]
ops
.
reshape_and_cache_flash
(
keys
[
i
-
model_executable
.
model
.
start_layer
].
to
(
key_cache
.
device
),
values
[
i
-
model_executable
.
model
.
start_layer
].
to
(
value_cache
.
device
),
key_cache
,
value_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
for
cur_layer
in
range
(
start_layer
,
end_layer
):
layer_id
=
cur_layer
-
start_layer
kv_cache
=
kv_caches
[
layer_id
]
layer
=
model_executable
.
model
.
layers
[
cur_layer
]
# get remote kvcache
remote_k
,
remote_v
=
keys
[
layer_id
],
values
[
layer_id
]
self
.
kv_helper
.
put_kv_to_cache
(
model_executable
,
remote_k
,
remote_v
,
layer
,
kv_cache
,
slot_mapping
,
start_pos
,
end_pos
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
0 → 100644
View file @
3ac98edc
# SPDX-License-Identifier: Apache-2.0
"""
Model-aware helpers for reading KV caches from, and writing them into, paged KV-cache memory in the KV-transfer connectors.
"""
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class model_aware_kv_ops_helper:
    """Hide model-specific KV-cache layout details from KV connectors.

    Two layouts are handled, selected by ``is_deepseek_mla`` and
    ``use_mla_opt``:

    * Deepseek MLA with the MLA optimisation enabled: a single latent cache
      that is reshaped as a whole (no separate key/value planes).
    * Everything else: ``kv_cache[0]`` holds keys and ``kv_cache[1]`` holds
      values.
    """

    def __init__(self, config: "VllmConfig"):
        """Capture the settings that decide which cache layout is in use.

        Args:
            config: The vLLM configuration; only ``model_config`` and
                ``parallel_config`` are read.
        """
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable: torch.nn.Module):
        """Return ``(num_heads, head_size)`` used to reshape the KV cache.

        Args:
            model_executable: Model wrapper exposing ``.model.config``. It is
                also stored on ``self.model_executable``.

        Returns:
            A ``(num_heads, head_size)`` tuple for the current rank.
        """
        model_config = model_executable.model.config
        self.model_executable = model_executable

        # Integer division avoids float rounding. The floor of 1 covers the
        # case where tp_size exceeds the number of KV heads (heads are then
        # replicated across ranks); 0 would make the later reshape invalid.
        num_heads = max(1, model_config.num_key_value_heads // self.tp_size)

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/attention/backends/mla/common.py.
        if self.is_deepseek_mla:
            if self.use_mla_opt:
                num_heads = 1
                head_size = (model_config.kv_lora_rank +
                             model_config.qk_rope_head_dim)
            else:
                head_size = (model_config.qk_nope_head_dim +
                             model_config.qk_rope_head_dim)
        else:
            # Some configs define head_dim explicitly as None; treat that the
            # same as a missing attribute and derive it from the hidden size.
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                head_size = (model_config.hidden_size //
                             model_config.num_attention_heads)

        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return key and value caches reshaped to ``(-1, num_heads, head_size)``.

        Args:
            kv_cache: The cache tensor of one layer.
            num_heads: Head count, as returned by ``get_model_args``.
            head_size: Head size, as returned by ``get_model_args``.

        Returns:
            A ``(key_cache, value_cache)`` tuple. In the optimised MLA layout
            both entries are reshapes of the same whole cache.
        """
        if self.is_deepseek_mla and self.use_mla_opt:
            # There is only one latent cache, so one reshape serves both.
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = key_cache
        else:
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
                        layer, kv_cache, slot_mapping, start_pos, end_pos):
        """Write received keys/values of one layer into ``kv_cache``.

        Args:
            model_executable: Model wrapper exposing ``.model.config``.
            keys: Received keys for this layer; in the optimised MLA layout
                this holds the concatenated latent and rope parts.
            values: Received values for this layer (unused in the optimised
                MLA layout).
            layer: The decoder layer owning the cache; its
                ``self_attn.attn`` supplies the cache dtype and scales.
            kv_cache: The destination cache tensor of this layer.
            slot_mapping: Flat slot mapping; the ``start_pos:end_pos`` slice
                selects the destination slots.
            start_pos: First index into ``slot_mapping``.
            end_pos: One past the last index into ``slot_mapping``.
        """
        slots = slot_mapping[start_pos:end_pos]

        if self.is_deepseek_mla and self.use_mla_opt:
            model_config = model_executable.model.config
            # NOTE(review): this rebinds ``attn`` on the layer itself and is
            # kept as in the original; confirm whether callers rely on it.
            layer.self_attn.attn = layer.self_attn.mla_attn
            attn = layer.self_attn.attn

            # Split the packed tensor at kv_lora_rank into its two parts.
            k_c_normed_k_pe = keys.squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slots,
                attn.kv_cache_dtype,
                attn._k_scale,
            )
        else:
            attn = layer.self_attn.attn
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slots,
                attn.kv_cache_dtype,
                attn._k_scale,
                attn._v_scale,
            )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment