Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ac98edc
Unverified
Commit
3ac98edc
authored
Apr 16, 2025
by
billishyahao
Committed by
GitHub
Apr 15, 2025
Browse files
[Feature] add model aware kv ops helper (#16020)
Signed-off-by:
billishyahao
<
bill.he@amd.com
>
parent
966c742e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
101 deletions
+123
-101
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
...uted/kv_transfer/kv_connector/mooncake_store_connector.py
+12
-27
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
.../distributed/kv_transfer/kv_connector/simple_connector.py
+21
-74
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+90
-0
No files found.
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
View file @
3ac98edc
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""
"""
MooncakeStore Connector for Distributed Machine Learning Inference
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
database-style KVStore.
...
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union
...
@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
...
@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
config
:
VllmConfig
,
config
:
VllmConfig
,
):
):
self
.
config
=
config
.
kv_transfer_config
self
.
config
=
config
.
kv_transfer_config
self
.
tp_size
=
config
.
parallel_config
.
tensor_parallel_size
self
.
kv_helper
=
kv_helper
(
config
)
self
.
local_tp_rank
=
local_rank
self
.
local_tp_rank
=
local_rank
# Init kv_store
# Init kv_store
...
@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
...
@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
slot_mapping_flat
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
slot_mapping_flat
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
start_layer
=
model_executable
.
model
.
start_layer
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
end_layer
=
model_executable
.
model
.
end_layer
num_heads
,
head_size
=
self
.
kv_helper
.
get_model_args
(
model_executable
)
model_config
=
model_executable
.
model
.
config
num_heads
=
int
(
model_config
.
num_key_value_heads
/
self
.
tp_size
)
hidden_size
=
model_config
.
hidden_size
num_attention_heads
=
model_config
.
num_attention_heads
head_size
=
int
(
hidden_size
/
num_attention_heads
)
for
idx
,
slen
in
enumerate
(
seq_lens
):
for
idx
,
slen
in
enumerate
(
seq_lens
):
start_pos
=
sum
(
seq_lens
[:
idx
])
start_pos
=
sum
(
seq_lens
[:
idx
])
...
@@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase):
...
@@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase):
for
layer_id
in
range
(
start_layer
,
end_layer
):
for
layer_id
in
range
(
start_layer
,
end_layer
):
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
,
value_cache
=
self
.
kv_helper
.
get_kv_from_cache
(
key_cache
=
kv_cache
[
0
].
reshape
(
-
1
,
num_heads
,
head_size
)
kv_cache
,
num_heads
,
head_size
)
value_cache
=
kv_cache
[
1
].
reshape
(
-
1
,
num_heads
,
head_size
)
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
keys
.
append
(
key_cache
[
current_slot_mapping
].
unsqueeze
(
0
))
keys
.
append
(
key_cache
[
current_slot_mapping
].
unsqueeze
(
0
))
...
@@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase):
...
@@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase):
layer
=
model_executable
.
model
.
layers
[
layer_id
]
layer
=
model_executable
.
model
.
layers
[
layer_id
]
# get kvcache object
# get kvcache object
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
,
value_cache
=
kv_cache
[
0
],
kv_cache
[
1
]
# get remote kvcache
# get remote kvcache
remote_k
,
remote_v
=
remote_kv
[
0
][
layer_id
],
remote_kv
[
1
][
remote_k
,
remote_v
=
remote_kv
[
0
][
layer_id
],
remote_kv
[
1
][
layer_id
]
layer_id
]
# use ops.reshape_and_cache_flash to put kv into kvcache
ops
.
reshape_and_cache_flash
(
self
.
kv_helper
.
put_kv_to_cache
(
model_executable
,
remote_k
,
remote_k
.
to
(
key_cache
.
device
),
remote_v
,
layer
,
kv_cache
,
remote_v
.
to
(
value_cache
.
device
),
slot_mapping
,
start_pos
,
key_cache
,
end_pos
)
value_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
...
...
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
View file @
3ac98edc
...
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
...
@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer
import
(
from
vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer
import
(
SimpleBuffer
)
SimpleBuffer
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
...
@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
):
):
self
.
config
=
config
.
kv_transfer_config
self
.
config
=
config
.
kv_transfer_config
self
.
tp_size
=
config
.
parallel_config
.
tensor_parallel_size
self
.
kv_helper
=
kv_helper
(
config
)
self
.
is_deepseek_mla
=
config
.
model_config
.
is_deepseek_mla
self
.
use_mla_opt
=
not
envs
.
VLLM_MLA_DISABLE
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
from
vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe
import
(
from
vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe
import
(
...
@@ -165,31 +163,7 @@ class SimpleConnector(KVConnectorBase):
...
@@ -165,31 +163,7 @@ class SimpleConnector(KVConnectorBase):
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
start_layer
=
model_executable
.
model
.
start_layer
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
end_layer
=
model_executable
.
model
.
end_layer
num_heads
,
head_size
=
self
.
kv_helper
.
get_model_args
(
model_executable
)
model_config
=
model_executable
.
model
.
config
num_heads
=
int
(
model_config
.
num_key_value_heads
/
self
.
tp_size
)
hidden_size
=
model_config
.
hidden_size
num_attention_heads
=
model_config
.
num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
head_size
=
model_config
.
kv_lora_rank
+
\
model_config
.
qk_rope_head_dim
num_heads
=
1
elif
self
.
is_deepseek_mla
and
not
self
.
use_mla_opt
:
head_size
=
model_config
.
qk_nope_head_dim
+
\
model_config
.
qk_rope_head_dim
else
:
head_size
=
getattr
(
model_config
,
"head_dim"
,
int
(
hidden_size
//
num_attention_heads
))
# query_lens contains new KV caches that are added to vLLM.
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# so we will send them to decode instance
...
@@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase):
...
@@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase):
for
layer_id
in
range
(
start_layer
,
end_layer
):
for
layer_id
in
range
(
start_layer
,
end_layer
):
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
,
value_cache
=
self
.
kv_helper
.
get_kv_from_cache
(
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
kv_cache
,
num_heads
,
head_size
)
key_cache
=
kv_cache
.
reshape
(
-
1
,
num_heads
,
head_size
)
value_cache
=
kv_cache
.
reshape
(
-
1
,
num_heads
,
head_size
)
else
:
key_cache
=
kv_cache
[
0
].
reshape
(
-
1
,
num_heads
,
head_size
)
value_cache
=
kv_cache
[
1
].
reshape
(
-
1
,
num_heads
,
head_size
)
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
...
@@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase):
...
@@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase):
# and hidden states.
# and hidden states.
bypass_model_exec
=
True
bypass_model_exec
=
True
model_config
=
model_executable
.
model
.
config
input_tokens_tensor
=
model_input
.
input_tokens
input_tokens_tensor
=
model_input
.
input_tokens
seq_lens
=
model_input
.
attn_metadata
.
seq_lens
seq_lens
=
model_input
.
attn_metadata
.
seq_lens
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
hidden_or_intermediate_states_for_one_req
=
[]
hidden_or_intermediate_states_for_one_req
=
[]
...
@@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase):
...
@@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase):
end_pos
=
start_pos
+
num_computed_tokens
end_pos
=
start_pos
+
num_computed_tokens
# put received KV caches into paged memory
# put received KV caches into paged memory
for
i
in
range
(
model_executable
.
model
.
start_layer
,
for
cur_layer
in
range
(
start_layer
,
end_layer
):
model_executable
.
model
.
end_layer
):
layer_id
=
cur_layer
-
start_layer
kv_cache
=
kv_caches
[
i
-
model_executable
.
model
.
start_layer
]
kv_cache
=
kv_caches
[
layer_id
]
layer
=
model_executable
.
model
.
layers
[
i
]
layer
=
model_executable
.
model
.
layers
[
cur_layer
]
if
self
.
is_deepseek_mla
and
self
.
use_mla_opt
:
# get remote kvcache
layer
.
self_attn
.
attn
=
layer
.
self_attn
.
mla_attn
remote_k
,
remote_v
=
keys
[
layer_id
],
values
[
layer_id
]
k_c_normed_k_pe
=
keys
[
i
-
model_executable
.
model
.
start_layer
].
to
(
self
.
kv_helper
.
put_kv_to_cache
(
model_executable
,
remote_k
,
kv_cache
.
device
).
squeeze
(
1
)
remote_v
,
layer
,
kv_cache
,
k_c_normed
=
k_c_normed_k_pe
[:,
:
model_config
.
kv_lora_rank
]
slot_mapping
,
start_pos
,
k_pe
=
k_c_normed_k_pe
[:,
model_config
.
kv_lora_rank
:]
end_pos
)
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
,
kv_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
)
else
:
key_cache
,
value_cache
=
kv_cache
[
0
],
kv_cache
[
1
]
ops
.
reshape_and_cache_flash
(
keys
[
i
-
model_executable
.
model
.
start_layer
].
to
(
key_cache
.
device
),
values
[
i
-
model_executable
.
model
.
start_layer
].
to
(
value_cache
.
device
),
key_cache
,
value_cache
,
slot_mapping
[
start_pos
:
end_pos
],
layer
.
self_attn
.
attn
.
kv_cache_dtype
,
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
hidden_or_intermediate_states_for_one_req
.
append
(
hidden
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
0 → 100644
View file @
3ac98edc
# SPDX-License-Identifier: Apache-2.0
"""
Model-aware helpers that KV connectors use to read KV tensors from, and write them back into, the paged KV cache.
"""
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class model_aware_kv_ops_helper:
    """Hide model-specific KV cache layout differences from KV connectors.

    Deepseek's MLA (Multi-head Latent Attention) uses two different kv_cache
    shapes based on whether VLLM_MLA_DISABLE is set to 0.
    When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, resulting
    in a kv_cache shape of
    [num_blks, blk_size, 1, kv_lora_rank + qk_rope_head_dim].
    When VLLM_MLA_DISABLE=1, standard FA is used instead, leading to a
    kv_cache shape of
    [2, num_blks, blk_size, num_key_value_heads / tp,
     qk_nope_head_dim + qk_rope_head_dim].
    For more details, see vllm/attention/backends/mla/common.py.

    Every other model is treated as having a cache whose index 0 holds the
    keys and whose index 1 holds the values.
    """

    def __init__(self, config: "VllmConfig"):
        """Capture the settings that decide which cache layout is in use.

        Args:
            config: The vLLM configuration; ``model_config.is_deepseek_mla``
                and ``parallel_config.tensor_parallel_size`` are read from it.
        """
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def _uses_mla_latent_cache(self) -> bool:
        """Return True when the single-tensor MLA cache layout is active."""
        return bool(self.is_deepseek_mla and self.use_mla_opt)

    def get_model_args(self, model_executable: "torch.nn.Module"):
        """Return ``(num_heads, head_size)`` describing one cached KV token.

        Args:
            model_executable: Model wrapper exposing ``model.config``.

        Returns:
            A ``(num_heads, head_size)`` tuple suitable for passing to
            ``get_kv_from_cache``.
        """
        model_config = model_executable.model.config
        # Kept for backward compatibility: the helper has always remembered
        # the most recent model it was asked about.
        self.model_executable = model_executable

        if self._uses_mla_latent_cache():
            # One latent "head" holding the compressed KV plus the rope part.
            return 1, (model_config.kv_lora_rank +
                       model_config.qk_rope_head_dim)

        # Integer division avoids a float round trip; the floor of 1 covers
        # the case of fewer KV heads than tensor-parallel ranks, which
        # previously produced 0 heads.
        num_heads = max(1, model_config.num_key_value_heads // self.tp_size)

        if self.is_deepseek_mla:
            head_size = (model_config.qk_nope_head_dim +
                         model_config.qk_rope_head_dim)
        else:
            # Some configs define ``head_dim`` but leave it as None, so a
            # plain getattr default is not enough to trigger the fallback.
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                head_size = (model_config.hidden_size //
                             model_config.num_attention_heads)

        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return ``(key_cache, value_cache)`` reshaped to per-token rows.

        Args:
            kv_cache: The cache object of one layer.
            num_heads: Head count as returned by ``get_model_args``.
            head_size: Head size as returned by ``get_model_args``.

        Returns:
            Two objects reshaped to ``(-1, num_heads, head_size)``. With the
            MLA latent layout both are the same reshaped cache, because that
            layout has no separate value tensor.
        """
        token_shape = (-1, num_heads, head_size)

        if self._uses_mla_latent_cache():
            latent_cache = kv_cache.reshape(*token_shape)
            return latent_cache, latent_cache

        return (kv_cache[0].reshape(*token_shape),
                kv_cache[1].reshape(*token_shape))

    def put_kv_to_cache(self, model_executable: "torch.nn.Module", keys,
                        values, layer, kv_cache, slot_mapping, start_pos,
                        end_pos):
        """Write received KV tensors of one layer into the paged cache.

        Args:
            model_executable: Model wrapper exposing ``model.config``.
            keys: Received keys for this layer. With the MLA latent layout
                this is squeezed on dim 1 and split at ``kv_lora_rank``
                (presumably (tokens, 1, kv_lora_rank + qk_rope_head_dim) -
                TODO confirm against the sender).
            values: Received values for this layer; unused with the MLA
                latent layout.
            layer: The decoder layer; its ``self_attn`` supplies the cache
                dtype and scales.
            kv_cache: The cache object of this layer, written in place by
                the custom op.
            slot_mapping: Flat slot mapping for the whole batch.
            start_pos: Start of this request's range in ``slot_mapping``.
            end_pos: End (exclusive) of this request's range.
        """
        slots = slot_mapping[start_pos:end_pos]

        if self._uses_mla_latent_cache():
            attn = layer.self_attn.mla_attn
            # NOTE(review): this aliases ``mla_attn`` as ``attn`` on the
            # layer itself. Kept because other code may rely on the alias
            # existing afterwards - confirm before removing.
            layer.self_attn.attn = attn

            latent_rank = model_config_rank = \
                model_executable.model.config.kv_lora_rank
            k_c_normed_k_pe = keys.squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, :latent_rank]
            k_pe = k_c_normed_k_pe[:, model_config_rank:]

            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slots,
                attn.kv_cache_dtype,
                attn._k_scale,
            )
            return

        attn = layer.self_attn.attn
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        ops.reshape_and_cache_flash(
            keys.to(key_cache.device),
            values.to(value_cache.device),
            key_cache,
            value_cache,
            slots,
            attn.kv_cache_dtype,
            attn._k_scale,
            attn._v_scale,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment