Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
679
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1602 additions
and
1180 deletions
+1602
-1180
vllm/distributed/ec_transfer/ec_connector/example_connector.py
...distributed/ec_transfer/ec_connector/example_connector.py
+4
-4
vllm/distributed/ec_transfer/ec_connector/factory.py
vllm/distributed/ec_transfer/ec_connector/factory.py
+3
-3
vllm/distributed/eplb/__init__.py
vllm/distributed/eplb/__init__.py
+1
-6
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+15
-5
vllm/distributed/eplb/policy/__init__.py
vllm/distributed/eplb/policy/__init__.py
+19
-0
vllm/distributed/eplb/policy/abstract.py
vllm/distributed/eplb/policy/abstract.py
+40
-0
vllm/distributed/eplb/policy/default.py
vllm/distributed/eplb/policy/default.py
+267
-0
vllm/distributed/eplb/rebalance_algo.py
vllm/distributed/eplb/rebalance_algo.py
+0
-260
vllm/distributed/eplb/rebalance_execute.py
vllm/distributed/eplb/rebalance_execute.py
+0
-3
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+8
-3
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+125
-86
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+15
-1
vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
...tributed/kv_transfer/kv_connector/v1/example_connector.py
+5
-5
vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
...ributed/kv_transfer/kv_connector/v1/mooncake_connector.py
+914
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+4
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+182
-153
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
+0
-179
vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
...istributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
+0
-164
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
...distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
+0
-242
vllm/distributed/kv_transfer/kv_pipe/base.py
vllm/distributed/kv_transfer/kv_pipe/base.py
+0
-66
No files found.
Too many changes to show.
To preserve performance only
679 of 679+
files are displayed.
Plain diff
Email patch
vllm/distributed/ec_transfer/ec_connector/
shared_storag
e_connector.py
→
vllm/distributed/ec_transfer/ec_connector/
exampl
e_connector.py
View file @
8d75f22e
...
@@ -32,7 +32,7 @@ class MMMeta:
...
@@ -32,7 +32,7 @@ class MMMeta:
@
dataclass
@
dataclass
class
EC
SharedStorag
eConnectorMetadata
(
ECConnectorMetadata
):
class
EC
Exampl
eConnectorMetadata
(
ECConnectorMetadata
):
mm_datas
:
list
[
MMMeta
]
mm_datas
:
list
[
MMMeta
]
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
...
@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
self
.
mm_datas
.
append
(
mm_data
)
self
.
mm_datas
.
append
(
mm_data
)
class
EC
SharedStorag
eConnector
(
ECConnectorBase
):
class
EC
Exampl
eConnector
(
ECConnectorBase
):
# NOTE: This is a simple debug implementation of the EC connector.
# NOTE: This is a simple debug implementation of the EC connector.
# It saves / loads the EC cache to / from the disk.
# It saves / loads the EC cache to / from the disk.
...
@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase):
...
@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase):
# Get the metadata
# Get the metadata
metadata
:
ECConnectorMetadata
=
self
.
_get_connector_metadata
()
metadata
:
ECConnectorMetadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
EC
SharedStorag
eConnectorMetadata
)
assert
isinstance
(
metadata
,
EC
Exampl
eConnectorMetadata
)
assert
encoder_cache
is
not
None
assert
encoder_cache
is
not
None
if
metadata
is
None
:
if
metadata
is
None
:
logger
.
warning
(
logger
.
warning
(
...
@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase):
...
@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase):
Args:
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
scheduler_output (SchedulerOutput): the scheduler output object.
"""
"""
meta
=
EC
SharedStorag
eConnectorMetadata
()
meta
=
EC
Exampl
eConnectorMetadata
()
for
mm_hash
,
num_encoder_token
in
self
.
_mm_datas_need_loads
.
items
():
for
mm_hash
,
num_encoder_token
in
self
.
_mm_datas_need_loads
.
items
():
meta
.
add_mm_data
(
MMMeta
.
make_meta
(
mm_hash
,
num_encoder_token
))
meta
.
add_mm_data
(
MMMeta
.
make_meta
(
mm_hash
,
num_encoder_token
))
self
.
_mm_datas_need_loads
.
clear
()
self
.
_mm_datas_need_loads
.
clear
()
...
...
vllm/distributed/ec_transfer/ec_connector/factory.py
View file @
8d75f22e
...
@@ -79,7 +79,7 @@ class ECConnectorFactory:
...
@@ -79,7 +79,7 @@ class ECConnectorFactory:
# only load the files corresponding to the current connector.
# only load the files corresponding to the current connector.
ECConnectorFactory
.
register_connector
(
ECConnectorFactory
.
register_connector
(
"EC
SharedStorag
eConnector"
,
"EC
Exampl
eConnector"
,
"vllm.distributed.ec_transfer.ec_connector.
shared_storag
e_connector"
,
"vllm.distributed.ec_transfer.ec_connector.
exampl
e_connector"
,
"EC
SharedStorag
eConnector"
,
"EC
Exampl
eConnector"
,
)
)
vllm/distributed/eplb/__init__.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
"""Expert parallelism load balancer (EPLB)."""
Expert parallelism load balancer (EPLB).
"""
from
.eplb_state
import
*
from
.rebalance_algo
import
*
vllm/distributed/eplb/eplb_state.py
View file @
8d75f22e
...
@@ -45,7 +45,7 @@ from vllm.logger import init_logger
...
@@ -45,7 +45,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.models.interfaces
import
MixtureOfExperts
from
vllm.model_executor.models.interfaces
import
MixtureOfExperts
from
.async_worker
import
start_async_worker
from
.async_worker
import
start_async_worker
from
.
rebalance_algo
import
rebalance_experts
from
.
policy
import
EPLB_POLICIES
,
AbstractEplbPolicy
,
DefaultEplbPolicy
from
.rebalance_execute
import
move_from_buffer
,
rearrange_expert_weights_inplace
from
.rebalance_execute
import
move_from_buffer
,
rearrange_expert_weights_inplace
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -213,18 +213,23 @@ class EplbState:
...
@@ -213,18 +213,23 @@ class EplbState:
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
device
=
device
self
.
device
=
device
self
.
model_states
:
dict
[
str
,
EplbModelState
]
=
{}
self
.
model_states
:
dict
[
str
,
EplbModelState
]
=
{}
self
.
policy
:
type
[
AbstractEplbPolicy
]
=
DefaultEplbPolicy
"""
Selected EPLB algorithm class
"""
self
.
expert_load_window_step
:
int
=
0
"""
"""
Current step in the sliding window.
Current step in the sliding window.
Different from `expert_rearrangement_step`,
Different from `expert_rearrangement_step`,
each EP rank may have its own `expert_load_window_step`.
each EP rank may have its own `expert_load_window_step`.
"""
"""
self
.
expert_load_window_s
tep
:
int
=
0
self
.
expert_load_window_s
ize
:
int
=
0
"""
"""
Size of the expert load sliding window.
Size of the expert load sliding window.
This is a constant and is taken from the config.
This is a constant and is taken from the config.
"""
"""
self
.
expert_
load_window_size
:
int
=
0
self
.
expert_
rearrangement_step
:
int
=
0
"""
"""
Steps after last rearrangement.
Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold.
Will trigger a rearrangement if it exceeds the threshold.
...
@@ -415,6 +420,10 @@ class EplbState:
...
@@ -415,6 +420,10 @@ class EplbState:
)
)
self
.
expert_rearrangement_step_interval
=
eplb_step_interval
self
.
expert_rearrangement_step_interval
=
eplb_step_interval
# Set the policy based on the selected EPLB algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
# policy_type is a string key of EPLB_POLICIES (e.g. "default"), so it must
# be rendered with %s; %d makes the log record fail to format when debug
# logging is enabled and the selected policy is never reported.
logger.debug("Selected EPLB policy: %s", policy_type)
if
global_expert_load
is
not
None
:
if
global_expert_load
is
not
None
:
ep_group
=
get_ep_group
().
device_group
ep_group
=
get_ep_group
().
device_group
assert
global_expert_load
.
shape
==
(
assert
global_expert_load
.
shape
==
(
...
@@ -441,7 +450,7 @@ class EplbState:
...
@@ -441,7 +450,7 @@ class EplbState:
new_physical_to_logical_map
,
new_physical_to_logical_map
,
new_logical_to_physical_map
,
new_logical_to_physical_map
,
new_logical_replica_count
,
new_logical_replica_count
,
)
=
rebalance_experts
(
)
=
self
.
policy
.
rebalance_experts
(
global_expert_load
,
global_expert_load
,
num_replicas
,
num_replicas
,
num_groups
,
num_groups
,
...
@@ -776,6 +785,7 @@ class EplbState:
...
@@ -776,6 +785,7 @@ class EplbState:
f
"
{
num_gpus
=
}
,
{
num_nodes
=
}
"
f
"
{
num_gpus
=
}
,
{
num_nodes
=
}
"
)
)
# Get new expert mappings
for
eplb_model_state
,
global_expert_load_window
in
zip
(
for
eplb_model_state
,
global_expert_load_window
in
zip
(
self
.
model_states
.
values
(),
global_expert_load_windows
self
.
model_states
.
values
(),
global_expert_load_windows
):
):
...
@@ -784,7 +794,7 @@ class EplbState:
...
@@ -784,7 +794,7 @@ class EplbState:
new_physical_to_logical_map
,
new_physical_to_logical_map
,
new_logical_to_physical_map
,
new_logical_to_physical_map
,
new_logical_replica_count
,
new_logical_replica_count
,
)
=
rebalance_experts
(
)
=
self
.
policy
.
rebalance_experts
(
global_expert_load_window
,
global_expert_load_window
,
num_replicas
,
num_replicas
,
num_groups
,
num_groups
,
...
...
vllm/distributed/eplb/policy/__init__.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
get_args
from
vllm.config.parallel
import
EPLBPolicyOption
from
.abstract
import
AbstractEplbPolicy
from
.default
import
DefaultEplbPolicy
EPLB_POLICIES
=
{
"default"
:
DefaultEplbPolicy
}
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
assert
set
(
EPLB_POLICIES
.
keys
())
==
set
(
get_args
(
EPLBPolicyOption
))
__all__
=
[
"AbstractEplbPolicy"
,
"DefaultEplbPolicy"
,
"EPLB_POLICIES"
,
]
vllm/distributed/eplb/policy/abstract.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
torch
class
AbstractEplbPolicy
(
ABC
):
@
classmethod
@
abstractmethod
def
rebalance_experts
(
cls
,
weight
:
torch
.
Tensor
,
num_replicas
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_ranks
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics
for all logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica
logical_to_physical_map: [layers, num_logical_experts, X],
the replica indices for each expert
expert_count: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
raise
NotImplementedError
vllm/distributed/eplb/policy/default.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import
numpy
as
np
import
torch
from
.abstract
import
AbstractEplbPolicy
class
DefaultEplbPolicy
(
AbstractEplbPolicy
):
@
classmethod
def
balanced_packing
(
cls
,
weight
:
torch
.
Tensor
,
num_packs
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers
,
num_groups
=
weight
.
shape
assert
num_groups
%
num_packs
==
0
groups_per_pack
=
num_groups
//
num_packs
device
=
weight
.
device
if
groups_per_pack
==
1
:
pack_index
=
torch
.
arange
(
weight
.
size
(
-
1
),
dtype
=
torch
.
int64
,
device
=
device
).
expand
(
weight
.
shape
)
rank_in_pack
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int64
,
device
=
device
)
return
pack_index
,
rank_in_pack
weight_np
=
weight
.
cpu
().
numpy
()
# Sort and get indices in descending order
indices_np
=
np
.
argsort
(
-
weight_np
,
axis
=-
1
)
pack_index_np
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
rank_in_pack_np
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
# Run the packing algorithm
for
i
in
range
(
num_layers
):
pack_weights
=
[
0.0
]
*
num_packs
pack_items
=
[
0
]
*
num_packs
for
group
in
indices_np
[
i
]:
# Find a pack with capacity that has the lowest weight
pack
=
min
(
(
j
for
j
in
range
(
num_packs
)
if
pack_items
[
j
]
<
groups_per_pack
),
key
=
pack_weights
.
__getitem__
,
)
assert
pack_items
[
pack
]
<
groups_per_pack
pack_index_np
[
i
,
group
]
=
pack
rank_in_pack_np
[
i
,
group
]
=
pack_items
[
pack
]
pack_weights
[
pack
]
+=
weight_np
[
i
,
group
]
pack_items
[
pack
]
+=
1
pack_index
=
torch
.
from_numpy
(
pack_index_np
).
to
(
device
)
rank_in_pack
=
torch
.
from_numpy
(
rank_in_pack_np
).
to
(
device
)
return
pack_index
,
rank_in_pack
@
classmethod
def
replicate_experts
(
cls
,
weight
:
torch
.
Tensor
,
num_phy
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n
,
num_log
=
weight
.
shape
num_redundant
=
num_phy
-
num_log
assert
num_redundant
>=
0
device
=
weight
.
device
phy2log
=
torch
.
arange
(
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
).
repeat
(
n
,
1
)
rank
=
torch
.
zeros
(
n
,
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
)
logcnt
=
torch
.
ones
(
n
,
num_log
,
dtype
=
torch
.
int64
,
device
=
device
)
arangen
=
torch
.
arange
(
n
,
dtype
=
torch
.
int64
,
device
=
device
)
for
i
in
range
(
num_log
,
num_phy
):
redundant_indices
=
(
weight
/
logcnt
).
max
(
dim
=-
1
).
indices
phy2log
[:,
i
]
=
redundant_indices
rank
[:,
i
]
=
logcnt
[
arangen
,
redundant_indices
]
logcnt
[
arangen
,
redundant_indices
]
+=
1
return
phy2log
,
rank
,
logcnt
@
classmethod
def
rebalance_experts_hierarchical
(
cls
,
weight
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_gpus
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers
,
num_logical_experts
=
weight
.
shape
assert
num_logical_experts
%
num_groups
==
0
group_size
=
num_logical_experts
//
num_groups
assert
num_groups
%
num_nodes
==
0
groups_per_node
=
num_groups
//
num_nodes
assert
num_gpus
%
num_nodes
==
0
assert
num_physical_experts
%
num_gpus
==
0
phy_experts_per_gpu
=
num_physical_experts
//
num_gpus
def
inverse
(
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
inv
=
torch
.
empty_like
(
perm
)
inv
.
scatter_
(
1
,
perm
,
torch
.
arange
(
perm
.
size
(
1
),
dtype
=
torch
.
int64
,
device
=
perm
.
device
).
expand
(
perm
.
shape
),
)
return
inv
# Step 1: pack groups to nodes
tokens_per_group
=
weight
.
unflatten
(
-
1
,
(
num_groups
,
group_size
)).
sum
(
-
1
)
group_pack_index
,
group_rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_group
,
num_nodes
)
log2mlog
=
(
(
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)
*
group_size
).
unsqueeze
(
-
1
)
+
torch
.
arange
(
group_size
,
dtype
=
torch
.
int64
,
device
=
group_pack_index
.
device
)
).
flatten
(
-
2
)
mlog2log
=
inverse
(
log2mlog
)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog
=
weight
.
gather
(
-
1
,
mlog2log
).
view
(
-
1
,
num_logical_experts
//
num_nodes
)
phy2mlog
,
phyrank
,
mlogcnt
=
cls
.
replicate_experts
(
tokens_per_mlog
,
num_physical_experts
//
num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy
=
(
tokens_per_mlog
/
mlogcnt
).
gather
(
-
1
,
phy2mlog
)
pack_index
,
rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_phy
,
num_gpus
//
num_nodes
)
phy2pphy
=
pack_index
*
phy_experts_per_gpu
+
rank_in_pack
pphy2phy
=
inverse
(
phy2pphy
)
pphy2mlog
=
phy2mlog
.
gather
(
-
1
,
pphy2phy
)
# [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog
=
(
pphy2mlog
.
view
(
num_layers
,
num_nodes
,
-
1
)
+
torch
.
arange
(
0
,
num_logical_experts
,
num_logical_experts
//
num_nodes
,
device
=
group_pack_index
.
device
,
).
view
(
1
,
-
1
,
1
)
).
flatten
(
-
2
)
pphy2log
=
mlog2log
.
gather
(
-
1
,
pphy2mlog
)
pphyrank
=
phyrank
.
gather
(
-
1
,
pphy2phy
).
view
(
num_layers
,
-
1
)
logcnt
=
mlogcnt
.
view
(
num_layers
,
-
1
).
gather
(
-
1
,
log2mlog
)
return
pphy2log
,
pphyrank
,
logcnt
@
classmethod
def
rebalance_experts
(
cls
,
weight
:
torch
.
Tensor
,
num_replicas
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_ranks
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers
,
num_logical_experts
=
weight
.
shape
weight
=
weight
.
float
()
if
num_groups
%
num_nodes
==
0
:
# use hierarchical load-balance policy
phy2log
,
phyrank
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_ranks
)
else
:
# use global load-balance policy
phy2log
,
phyrank
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
1
,
1
,
num_ranks
)
num_redundant_experts
=
num_replicas
-
num_logical_experts
maxlogcnt
=
num_redundant_experts
+
1
log2phy
:
torch
.
Tensor
=
torch
.
full
(
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
-
1
,
dtype
=
torch
.
int64
,
device
=
logcnt
.
device
,
)
log2phy
.
view
(
num_layers
,
-
1
).
scatter_
(
-
1
,
phy2log
*
maxlogcnt
+
phyrank
,
torch
.
arange
(
num_replicas
,
dtype
=
torch
.
int64
,
device
=
log2phy
.
device
).
expand
(
num_layers
,
-
1
),
)
return
phy2log
,
log2phy
,
logcnt
vllm/distributed/eplb/rebalance_algo.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import
numpy
as
np
import
torch
def
balanced_packing
(
weight
:
torch
.
Tensor
,
num_packs
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers
,
num_groups
=
weight
.
shape
assert
num_groups
%
num_packs
==
0
groups_per_pack
=
num_groups
//
num_packs
device
=
weight
.
device
if
groups_per_pack
==
1
:
pack_index
=
torch
.
arange
(
weight
.
size
(
-
1
),
dtype
=
torch
.
int64
,
device
=
device
).
expand
(
weight
.
shape
)
rank_in_pack
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int64
,
device
=
device
)
return
pack_index
,
rank_in_pack
weight_np
=
weight
.
cpu
().
numpy
()
# Sort and get indices in decending order
indices_np
=
np
.
argsort
(
-
weight_np
,
axis
=-
1
)
pack_index_np
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
rank_in_pack_np
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
# Run the packing algorithm
for
i
in
range
(
num_layers
):
pack_weights
=
[
0.0
]
*
num_packs
pack_items
=
[
0
]
*
num_packs
for
group
in
indices_np
[
i
]:
# Find a pack with capacity that has the lowest weight
pack
=
min
(
(
j
for
j
in
range
(
num_packs
)
if
pack_items
[
j
]
<
groups_per_pack
),
key
=
pack_weights
.
__getitem__
,
)
assert
pack_items
[
pack
]
<
groups_per_pack
pack_index_np
[
i
,
group
]
=
pack
rank_in_pack_np
[
i
,
group
]
=
pack_items
[
pack
]
pack_weights
[
pack
]
+=
weight_np
[
i
,
group
]
pack_items
[
pack
]
+=
1
pack_index
=
torch
.
from_numpy
(
pack_index_np
).
to
(
device
)
rank_in_pack
=
torch
.
from_numpy
(
rank_in_pack_np
).
to
(
device
)
return
pack_index
,
rank_in_pack
def
replicate_experts
(
weight
:
torch
.
Tensor
,
num_phy
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n
,
num_log
=
weight
.
shape
num_redundant
=
num_phy
-
num_log
assert
num_redundant
>=
0
device
=
weight
.
device
phy2log
=
torch
.
arange
(
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
).
repeat
(
n
,
1
)
rank
=
torch
.
zeros
(
n
,
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
)
logcnt
=
torch
.
ones
(
n
,
num_log
,
dtype
=
torch
.
int64
,
device
=
device
)
arangen
=
torch
.
arange
(
n
,
dtype
=
torch
.
int64
,
device
=
device
)
for
i
in
range
(
num_log
,
num_phy
):
redundant_indices
=
(
weight
/
logcnt
).
max
(
dim
=-
1
).
indices
phy2log
[:,
i
]
=
redundant_indices
rank
[:,
i
]
=
logcnt
[
arangen
,
redundant_indices
]
logcnt
[
arangen
,
redundant_indices
]
+=
1
return
phy2log
,
rank
,
logcnt
def
rebalance_experts_hierarchical
(
weight
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_gpus
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map (torch.Tensor):
[num_moe_layers, num_physical_experts]
logical_to_physical_map (torch.Tensor):
[num_moe_layers, num_logical_experts, X]
logical_count (torch.Tensor):
[num_moe_layers, num_logical_experts]
"""
num_layers
,
num_logical_experts
=
weight
.
shape
assert
num_logical_experts
%
num_groups
==
0
group_size
=
num_logical_experts
//
num_groups
assert
num_groups
%
num_nodes
==
0
groups_per_node
=
num_groups
//
num_nodes
assert
num_gpus
%
num_nodes
==
0
assert
num_physical_experts
%
num_gpus
==
0
phy_experts_per_gpu
=
num_physical_experts
//
num_gpus
def
inverse
(
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
inv
=
torch
.
empty_like
(
perm
)
inv
.
scatter_
(
1
,
perm
,
torch
.
arange
(
perm
.
size
(
1
),
dtype
=
torch
.
int64
,
device
=
perm
.
device
).
expand
(
perm
.
shape
),
)
return
inv
# Step 1: pack groups to nodes
tokens_per_group
=
weight
.
unflatten
(
-
1
,
(
num_groups
,
group_size
)).
sum
(
-
1
)
group_pack_index
,
group_rank_in_pack
=
balanced_packing
(
tokens_per_group
,
num_nodes
)
log2mlog
=
(
(
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)
*
group_size
).
unsqueeze
(
-
1
)
+
torch
.
arange
(
group_size
,
dtype
=
torch
.
int64
,
device
=
group_pack_index
.
device
)
).
flatten
(
-
2
)
mlog2log
=
inverse
(
log2mlog
)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog
=
weight
.
gather
(
-
1
,
mlog2log
).
view
(
-
1
,
num_logical_experts
//
num_nodes
)
phy2mlog
,
phyrank
,
mlogcnt
=
replicate_experts
(
tokens_per_mlog
,
num_physical_experts
//
num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy
=
(
tokens_per_mlog
/
mlogcnt
).
gather
(
-
1
,
phy2mlog
)
pack_index
,
rank_in_pack
=
balanced_packing
(
tokens_per_phy
,
num_gpus
//
num_nodes
)
phy2pphy
=
pack_index
*
phy_experts_per_gpu
+
rank_in_pack
pphy2phy
=
inverse
(
phy2pphy
)
pphy2mlog
=
phy2mlog
.
gather
(
-
1
,
pphy2phy
)
# [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog
=
(
pphy2mlog
.
view
(
num_layers
,
num_nodes
,
-
1
)
+
torch
.
arange
(
0
,
num_logical_experts
,
num_logical_experts
//
num_nodes
,
device
=
group_pack_index
.
device
,
).
view
(
1
,
-
1
,
1
)
).
flatten
(
-
2
)
pphy2log
=
mlog2log
.
gather
(
-
1
,
pphy2mlog
)
pphyrank
=
phyrank
.
gather
(
-
1
,
pphy2phy
).
view
(
num_layers
,
-
1
)
logcnt
=
mlogcnt
.
view
(
num_layers
,
-
1
).
gather
(
-
1
,
log2mlog
)
return
pphy2log
,
pphyrank
,
logcnt
def
rebalance_experts
(
weight
:
torch
.
Tensor
,
num_replicas
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_gpus
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map:
[layers, num_replicas], the expert index of each replica
logical_to_physical_map:
[layers, num_logical_experts, X], the replica indices for each
expert
expert_count:
[layers, num_logical_experts], number of physical
replicas for each logical expert
"""
num_layers
,
num_logical_experts
=
weight
.
shape
weight
=
weight
.
float
()
if
num_groups
%
num_nodes
==
0
:
# use hierarchical load-balance policy
phy2log
,
phyrank
,
logcnt
=
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_gpus
)
else
:
# use global load-balance policy
phy2log
,
phyrank
,
logcnt
=
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
1
,
1
,
num_gpus
)
num_redundant_experts
=
num_replicas
-
num_logical_experts
maxlogcnt
=
num_redundant_experts
+
1
log2phy
:
torch
.
Tensor
=
torch
.
full
(
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
-
1
,
dtype
=
torch
.
int64
,
device
=
logcnt
.
device
,
)
log2phy
.
view
(
num_layers
,
-
1
).
scatter_
(
-
1
,
phy2log
*
maxlogcnt
+
phyrank
,
torch
.
arange
(
num_replicas
,
dtype
=
torch
.
int64
,
device
=
log2phy
.
device
).
expand
(
num_layers
,
-
1
),
)
return
phy2log
,
log2phy
,
logcnt
__all__
=
[
"rebalance_experts"
]
vllm/distributed/eplb/rebalance_execute.py
View file @
8d75f22e
...
@@ -322,9 +322,6 @@ async def transfer_layer(
...
@@ -322,9 +322,6 @@ async def transfer_layer(
num_local_physical_experts
=
next
(
iter
(
expert_weights
[
0
])).
shape
[
0
]
num_local_physical_experts
=
next
(
iter
(
expert_weights
[
0
])).
shape
[
0
]
assert
new_global_expert_indices
.
shape
==
(
num_moe_layers
,
num_physical_experts
)
assert
new_global_expert_indices
.
shape
==
(
num_moe_layers
,
num_physical_experts
)
assert
num_physical_experts
==
ep_size
*
num_local_physical_experts
assert
num_physical_experts
==
ep_size
*
num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
is_unchanged
,
is_received_locally
,
experts_recv_loc
=
move_to_buffer
(
is_unchanged
,
is_received_locally
,
experts_recv_loc
=
move_to_buffer
(
num_local_experts
=
num_local_physical_experts
,
num_local_experts
=
num_local_physical_experts
,
...
...
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
8d75f22e
...
@@ -144,9 +144,9 @@ class KVConnectorFactory:
...
@@ -144,9 +144,9 @@ class KVConnectorFactory:
# only load the files corresponding to the current connector.
# only load the files corresponding to the current connector.
KVConnectorFactory
.
register_connector
(
KVConnectorFactory
.
register_connector
(
"
SharedStorag
eConnector"
,
"
Exampl
eConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.
shared_storag
e_connector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.
exampl
e_connector"
,
"
SharedStorag
eConnector"
,
"
Exampl
eConnector"
,
)
)
KVConnectorFactory
.
register_connector
(
KVConnectorFactory
.
register_connector
(
...
@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
...
@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector"
,
"DecodeBenchConnector"
,
"DecodeBenchConnector"
,
)
)
KVConnectorFactory
.
register_connector
(
"MooncakeConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector"
,
"MooncakeConnector"
,
)
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
8d75f22e
...
@@ -4,13 +4,14 @@
...
@@ -4,13 +4,14 @@
KV cache helper for store.
KV cache helper for store.
"""
"""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Literal
from
typing
import
TYPE_CHECKING
,
Literal
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm
import
_custom_ops
as
ops
from
vllm
.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
...
@@ -21,89 +22,6 @@ if TYPE_CHECKING:
...
@@ -21,89 +22,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class model_aware_kv_ops_helper:
    """Model-aware helpers for reading and writing paged KV caches.

    The helper hides the difference between the two cache layouts used by
    DeepSeek MLA models (see ``get_model_args``) and the regular layout in
    which keys and values are stored as ``kv_cache[0]`` / ``kv_cache[1]``.
    """

    def __init__(self, config: VllmConfig):
        """Capture the model/parallel settings the helper depends on."""
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def _uses_absorbed_mla(self):
        """Truthy when the DeepSeek MLA optimised cache layout is active."""
        return self.is_deepseek_mla and self.use_mla_opt

    def get_model_args(self, model_executable: torch.nn.Module):
        """Return ``(num_heads, head_size)`` describing one KV cache entry.

        Also remembers ``model_executable`` on the helper.
        """
        hf_config = model_executable.model.config
        self.model_executable = model_executable
        heads_per_rank = int(hf_config.num_key_value_heads / self.tp_size)
        hidden = hf_config.hidden_size
        attention_heads = hf_config.num_attention_heads

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/v1/attention/backends/mla/common.py.
        if not self.is_deepseek_mla:
            per_head = getattr(hf_config, "head_dim", None)
            if per_head is None:
                per_head = int(hidden // attention_heads)
            return heads_per_rank, per_head

        if self.use_mla_opt:
            return 1, hf_config.kv_lora_rank + hf_config.qk_rope_head_dim

        return (
            heads_per_rank,
            hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim,
        )

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return ``(key_cache, value_cache)`` reshaped to
        ``(-1, num_heads, head_size)``.

        With the optimised MLA layout both results are reshapes of the whole
        ``kv_cache``; otherwise they come from ``kv_cache[0]`` and
        ``kv_cache[1]`` respectively.
        """
        target_shape = (-1, num_heads, head_size)
        if self._uses_absorbed_mla():
            key_source = kv_cache
            value_source = kv_cache
        else:
            key_source = kv_cache[0]
            value_source = kv_cache[1]
        key_cache = key_source.reshape(*target_shape)
        value_cache = value_source.reshape(*target_shape)
        return key_cache, value_cache

    def put_kv_to_cache(
        self,
        model_executable: torch.nn.Module,
        keys,
        values,
        layer,
        kv_cache,
        slot_mapping,
        start_pos,
        end_pos,
    ):
        """Write ``keys`` / ``values`` into ``kv_cache`` at
        ``slot_mapping[start_pos:end_pos]`` using the cache op matching the
        active layout.
        """
        hf_config = model_executable.model.config
        slots = slot_mapping[start_pos:end_pos]

        if not self._uses_absorbed_mla():
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            attn = layer.self_attn.attn
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slots,
                attn.kv_cache_dtype,
                attn._k_scale,
                attn._v_scale,
            )
            return

        # The MLA path aliases ``attn`` to the layer's ``mla_attn`` before
        # reading the cache dtype and scale from it.
        layer.self_attn.attn = layer.self_attn.mla_attn
        latent_and_rope = keys.squeeze(1)
        lora_rank = hf_config.kv_lora_rank
        k_c_normed = latent_and_rope[:, :lora_rank]
        k_pe = latent_and_rope[:, lora_rank:]
        ops.concat_and_cache_mla(
            k_c_normed.to(kv_cache.device),
            k_pe.to(kv_cache.device),
            kv_cache,
            slots,
            layer.self_attn.attn.kv_cache_dtype,
            layer.self_attn.attn._k_scale,
        )
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
# used for faster transfer.
...
@@ -266,3 +184,124 @@ def copy_kv_blocks(
...
@@ -266,3 +184,124 @@ def copy_kv_blocks(
src_tensor
=
src_kv_caches
[
layer_name
]
src_tensor
=
src_kv_caches
[
layer_name
]
dst_tensor
=
dst_kv_caches
[
layer_name
]
dst_tensor
=
dst_kv_caches
[
layer_name
]
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
@dataclass
class TpKVTopology:
    """
    Tensor-parallel / KV-cache topology used to map the local TP worker onto
    the TP workers of a remote engine.

    ``remote_tp_size`` and ``remote_block_size`` are keyed by engine id; the
    entries stored under ``engine_id`` describe the local engine itself.
    """

    tp_rank: int
    remote_tp_size: dict[str, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: str
    remote_block_size: dict[str, int]

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        probe_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        has_five_dims = len(probe_shape) == 5
        self._is_kv_layout_blocks_first = has_five_dims and probe_shape[0] == 1

        backend_enum = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = backend_enum == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        """Whether the probed cache shape puts ``num_blocks`` first."""
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return (
            not self.is_mla
            and not self._use_pallas
            and not self.is_kv_layout_blocks_first
        )

    @property
    def tp_size(self) -> int:
        """TP size recorded for the local engine."""
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        """Block size recorded for the local engine."""
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers-per-remote TP
        workers. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`.
        """
        ratio, leftover = divmod(self.tp_size, remote_tp_size)
        assert leftover == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return ratio

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> float:
        """
        Calculate the block size ratio between local and remote TP.
        """
        ratio, leftover = divmod(self.block_size, remote_block_size)
        assert leftover == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return ratio

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """``tp_ratio`` for the TP size recorded under ``remote_engine_id``."""
        return self.tp_ratio(self.remote_tp_size[remote_engine_id])

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> float:
        """``block_size_ratio`` for the block size recorded under
        ``remote_engine_id``."""
        return self.block_size_ratio(self.remote_block_size[remote_engine_id])

    def is_kv_replicated(self, engine_id: str) -> bool:
        """
        Whether the KV cache is replicated across TP workers: true when the
        TP size recorded for ``engine_id`` floor-divided by the total number
        of KV heads is at least one.
        """
        return self.remote_tp_size[engine_id] // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: str) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        if self.is_mla:
            return self.is_mla
        return self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        return self.tp_rank // self.tp_ratio(remote_tp_size)

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """``get_target_remote_rank`` for the TP size recorded under
        ``remote_engine_id``."""
        return self.get_target_remote_rank(self.remote_tp_size[remote_engine_id])
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
8d75f22e
...
@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
...
@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return
return
def
register_cross_layers_kv_cache
(
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
"
AttentionBackend
"
]
):
):
"""
"""
Initialize with a single KV cache tensor used by all layers.
Initialize with a single KV cache tensor used by all layers.
...
@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
...
@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
expose connector transfer stats via Prometheus.
expose connector transfer stats via Prometheus.
"""
"""
return
None
return
None
def reset_cache(self) -> bool | None:
    """
    Reset the connector's internal cache.

    This default implementation performs no reset: it only emits a debug
    log naming the concrete connector class and returns ``None``.

    Returns:
        bool | None: implementations are expected to return True if the
        cache was successfully reset and False otherwise; the default
        returns None.
    """
    connector_name = type(self).__name__
    logger.debug(
        "Connector cache reset requested, but %s does not implement reset_cache().",
        connector_name,
    )
    return None
vllm/distributed/kv_transfer/kv_connector/v1/
shared_storag
e_connector.py
→
vllm/distributed/kv_transfer/kv_connector/v1/
exampl
e_connector.py
View file @
8d75f22e
...
@@ -65,7 +65,7 @@ class ReqMeta:
...
@@ -65,7 +65,7 @@ class ReqMeta:
@
dataclass
@
dataclass
class
SharedStorag
eConnectorMetadata
(
KVConnectorMetadata
):
class
Exampl
eConnectorMetadata
(
KVConnectorMetadata
):
requests
:
list
[
ReqMeta
]
=
field
(
default_factory
=
list
)
requests
:
list
[
ReqMeta
]
=
field
(
default_factory
=
list
)
def
add_request
(
def
add_request
(
...
@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
...
@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
)
)
class
SharedStorag
eConnector
(
KVConnectorBase_V1
):
class
Exampl
eConnector
(
KVConnectorBase_V1
):
# NOTE: This is Simple debug implementation of the KV connector.
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# It does extra work which will overwrite the existing prefix-cache in GPU
...
@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Get the metadata
# Get the metadata
metadata
:
KVConnectorMetadata
=
self
.
_get_connector_metadata
()
metadata
:
KVConnectorMetadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
SharedStorag
eConnectorMetadata
)
assert
isinstance
(
metadata
,
Exampl
eConnectorMetadata
)
if
metadata
is
None
:
if
metadata
is
None
:
logger
.
warning
(
logger
.
warning
(
...
@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
return
layer
.
reshape
(
2
,
num_pages
*
page_size
,
-
1
)[:,
slot_mapping
,
...]
connector_metadata
=
self
.
_get_connector_metadata
()
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
SharedStorag
eConnectorMetadata
)
assert
isinstance
(
connector_metadata
,
Exampl
eConnectorMetadata
)
for
request
in
connector_metadata
.
requests
:
for
request
in
connector_metadata
.
requests
:
if
request
.
is_store
:
if
request
.
is_store
:
filename
=
self
.
_generate_filename_debug
(
filename
=
self
.
_generate_filename_debug
(
...
@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
Args:
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
scheduler_output (SchedulerOutput): the scheduler output object.
"""
"""
meta
=
SharedStorag
eConnectorMetadata
()
meta
=
Exampl
eConnectorMetadata
()
total_need_load
=
0
total_need_load
=
0
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
threading
import
time
import
uuid
from
collections
import
defaultdict
from
concurrent.futures
import
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
import
numpy
as
np
import
torch
import
zmq
import
zmq.asyncio
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.utils.network_utils
import
get_ip
,
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
try
:
from
mooncake.engine
import
TransferEngine
except
ImportError
as
e
:
raise
ImportError
(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
# noqa: E501
"to run VLLM with MooncakeTransferEngine."
)
from
e
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
# Readability aliases for identifiers that are plain strings.
EngineId = str
ReqId = str

# Status frames sent back by the sender worker after a transfer attempt.
TRANS_DONE = b"trans_done"
TRANS_ERROR = b"trans_error"

logger = init_logger(__name__)
class MooncakeAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):
    """Message decoded by the sender side to serve a batch of transfers."""

    # Combined as "<remote_hostname>:<remote_port>" to name the transfer
    # target session.
    remote_hostname: str
    remote_port: int
    # Requests covered by this message.
    request_ids: list[ReqId]
    # Base addresses of the peer's KV cache regions, paired in order with
    # the local base addresses when building transfers.
    kv_caches_base_addr: list[int]
    # One list of peer block ids per entry of ``request_ids``.
    block_ids: list[list[int]]
@dataclass
class RecvReqMeta:
    """What a worker needs in order to receive the KV cache of one request."""

    # Local block ids associated with the request.
    local_block_ids: list[int]
    # Host and port taken from the request's ``kv_transfer_params``.
    remote_host: str
    remote_port: int
@dataclass
class SendBlockMeta:
    """Sender-side bookkeeping for one request waiting to be transferred."""

    # Local block ids to read from when sending.
    local_block_ids: list[int]
    # The sender waits on this event before reading ``local_block_ids``.
    ready: threading.Event
    # Entries whose expire_time is below the current perf_counter() value
    # are treated as expired; the infinite default never expires.
    expire_time: float = float("inf")
@dataclass
class SendReqMeta:
    """Pending send requests together with the lock guarding them."""

    # Pending requests keyed by request id.
    reqs: dict[ReqId, SendBlockMeta]
    # Held while ``reqs`` is read or modified.
    lock: threading.Lock
@dataclass
class FinishedSendReqSet:
    """Ids of requests that finished sending, guarded by a thread lock."""

    # NOTE: the field name shadows the builtin, but it is part of the
    # attribute interface used by callers (``.set``).
    set: set[ReqId]
    # Held while ``set`` is read or replaced.
    lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
    """Ids of requests that finished receiving, guarded by an asyncio lock."""

    # NOTE: the field name shadows the builtin, but it is part of the
    # attribute interface used by callers (``.set``).
    set: set[ReqId]
    # Acquired with ``async with`` while ``set`` is read or replaced.
    lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Per-step metadata listing requests to receive and requests to send."""

    def __init__(self):
        # Requests whose KV cache must be fetched, keyed by request id.
        self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
        # Local block ids of requests to be sent, keyed by request id.
        self.reqs_to_send: dict[ReqId, list[int]] = {}

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
    ):
        """Record ``request_id`` as a receive (default) or a send request.

        For a receive request ``kv_transfer_params`` must contain
        ``"remote_host"`` and ``"remote_port"``; for a send request it is
        not read.
        """
        if not load_remote_cache:
            self.reqs_to_send[request_id] = local_block_ids
            return

        host = kv_transfer_params["remote_host"]
        port = kv_transfer_params["remote_port"]
        self.reqs_to_recv[request_id] = RecvReqMeta(
            local_block_ids=local_block_ids,
            remote_host=host,
            remote_port=port,
        )
class MooncakeConnector(KVConnectorBase_V1):
    """KV connector facade that forwards every call to a role-specific
    implementation: ``MooncakeConnectorScheduler`` for the scheduler role and
    ``MooncakeConnectorWorker`` for the worker role.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        transfer_config = vllm_config.kv_transfer_config
        assert transfer_config is not None
        assert transfer_config.engine_id is not None
        self.engine_id: EngineId = transfer_config.engine_id

        # Exactly one of the two delegates is instantiated, the other is None.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: MooncakeConnectorScheduler | None = (
                MooncakeConnectorScheduler(vllm_config, self.engine_id)
            )
            self.connector_worker: MooncakeConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(
                vllm_config, self.engine_id
            )

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Forward to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Forward to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Forward to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Forward to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward to the worker-side implementation."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Get the finished recving and sending requests."""
        # ``finished_req_ids`` is not forwarded to the worker.
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Hand the bound connector metadata to the worker."""
        worker = self.connector_worker
        assert worker is not None
        metadata = self._connector_metadata
        assert isinstance(metadata, MooncakeConnectorMetadata)
        worker.start_load_kv(metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeConnector does not do layerwise saving."""
        return None

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """MooncakeConnector does not save explicitly."""
        return None

    def wait_for_save(self):
        """No-op."""
        return None
class MooncakeConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.engine_id: EngineId = engine_id
        # Advertised to the peer as remote_host / remote_port by
        # request_finished().
        self.side_channel_host = get_ip()
        self.side_channel_port = get_mooncake_side_channel_port(vllm_config)

        assert vllm_config.kv_transfer_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)

        # Requests that need to start recv/send.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
        self._reqs_need_send: dict[ReqId, list[int]] = {}

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens,
            params,
        )

        if params is None or not params.get("do_remote_prefill"):
            # No remote prefill for this request.
            return 0, False

        # Remote prefill: get all prompt blocks from remote.
        prompt_ids = request.prompt_token_ids or []
        remaining = len(prompt_ids) - num_computed_tokens
        if remaining > 0:
            return remaining, True
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Queue the request for receiving or sending after block allocation,
        depending on its ``kv_transfer_params``."""
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens,
            params,
        )

        if not params:
            return

        if not params.get("do_remote_prefill"):
            if params.get("do_remote_decode"):
                # Add an empty list to worker to create event.
                self._reqs_need_send[request.request_id] = []
            return

        assert self.kv_role != "kv_producer"
        if "remote_host" in params and "remote_port" in params:
            # If remote_blocks and num_external_tokens = 0, we have
            # a full prefix cache hit on the D worker. We need to call
            # send_notif in _read_blocks to free the memory on the P.
            if num_external_tokens > 0:
                local_block_ids = blocks.get_unhashed_block_ids()
            else:
                local_block_ids = []
            # Get unhashed blocks to pull from remote.
            self._reqs_need_recv[request.request_id] = (request, local_block_ids)
        else:
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer",
                params,
            )
        # Only trigger 1 KV transfer per request.
        params["do_remote_prefill"] = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Drain the queued recv/send requests into fresh connector metadata."""
        meta = MooncakeConnectorMetadata()

        # Loop through scheduled reqs and convert to RecvReqMeta.
        if self.kv_role != "kv_producer":
            for req_id, queued in self._reqs_need_recv.items():
                req, block_ids = queued
                assert req.kv_transfer_params is not None
                meta.add_new_req(
                    request_id=req_id,
                    local_block_ids=block_ids,
                    kv_transfer_params=req.kv_transfer_params,
                )
            self._reqs_need_recv.clear()

        if self.kv_role != "kv_consumer":
            for req_id, block_ids in self._reqs_need_send.items():
                meta.add_new_req(
                    request_id=req_id,
                    local_block_ids=block_ids,
                    kv_transfer_params={},
                    load_remote_cache=False,
                )
            self._reqs_need_send.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s",
            request.status,
            params,
        )
        if not params:
            return False, None

        if params.get("do_remote_prefill"):
            # If do_remote_prefill is still True when the request is finished,
            # update_state_after_alloc must not have been called (the request
            # must have been aborted before it was scheduled).
            # To avoid stranding the prefill blocks in the prefill instance,
            # we must add empty block_ids to _reqs_need_recv so that our
            # worker side will notify and free blocks in the prefill instance.
            assert self.kv_role != "kv_producer"
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None

        wants_remote_decode = params.get("do_remote_decode")
        if not (
            wants_remote_decode
            and request.status == RequestStatus.FINISHED_LENGTH_CAPPED
        ):
            return False, None

        assert self.kv_role != "kv_consumer"

        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        delay_free_blocks = len(block_ids) > 0
        if delay_free_blocks:
            self._reqs_need_send[request.request_id] = block_ids

        next_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_host": self.side_channel_host,
            "remote_port": self.side_channel_port,
        }
        return delay_free_blocks, next_params
class
MooncakeConnectorWorker
:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """Set up the worker-side state of the Mooncake connector.

    Initialises the Mooncake transfer engine, reads the parallel and
    KV-transfer configuration, creates the sender thread pool and/or starts
    the receiver event-loop thread depending on ``kv_role``, builds the
    ``TpKVTopology`` and finally creates the ZMQ contexts and msgpack codecs.

    Raises:
        RuntimeError: if the transfer engine fails to initialise.
    """
    logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)

    self.vllm_config = vllm_config

    self.engine = TransferEngine()
    self.hostname = get_ip()
    # A non-zero return value is treated as an initialisation failure.
    ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
    if ret_value != 0:
        raise RuntimeError("Mooncake Transfer Engine initialization failed.")

    self.rpc_port = self.engine.get_rpc_port()

    logger.debug(
        "Mooncake Transfer Engine initialized at %s:%d",
        self.hostname,
        self.rpc_port,
    )

    # Mooncake handshake port.
    self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)

    self.engine_id: EngineId = engine_id
    self.tp_rank = get_tensor_model_parallel_rank()
    self.world_size = get_tensor_model_parallel_world_size()
    self.tp_group = get_tp_group()
    # Filled in by register_kv_caches().
    self.num_blocks = 0

    assert vllm_config.kv_transfer_config
    self.kv_role = vllm_config.kv_transfer_config.kv_role
    # Size of the sender thread pool; defaults to 10 when not configured.
    self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "num_workers", 10
    )

    self.kv_caches_base_addr: list[int] = []
    self.device_kv_caches: dict[str, torch.Tensor] = {}
    self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())

    # For kv_both, we will act both prefiller and decoder.
    if self.kv_role != "kv_consumer":
        # Background thread for sending kvcaches to D.
        # Created later, in register_kv_caches().
        self._mooncake_sender_t: threading.Thread | None = None
        # Background thread for processing new sending requests.
        self._sender_executor = ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
        )
        logger.debug(
            "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
        )
    if self.kv_role != "kv_producer":
        # Dedicated event loop, run forever by a daemon thread.
        self.receiver_loop = asyncio.new_event_loop()
        self._mooncake_receiver_t = threading.Thread(
            target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
        )
        self._mooncake_receiver_t.start()
        logger.debug("Mooncake Decoder: start receiver thread")

    self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
        set(), threading.Lock()
    )
    self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
        set(), asyncio.Lock()
    )

    self.block_size = vllm_config.cache_config.block_size
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.use_mla = self.model_config.use_mla

    backend = get_attn_backend(
        self.model_config.get_head_size(),
        self.model_config.dtype,
        self.cache_config.cache_dtype,
        self.block_size,
        use_mla=self.use_mla,
    )
    self.backend_name = backend.get_name()
    self.kv_cache_layout = get_kv_cache_layout()
    logger.debug("Detected attention backend %s", self.backend_name)
    logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

    # Both dicts start with only the local engine's entry and are handed to
    # TpKVTopology by reference.
    self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
    self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
    self.kv_topo = TpKVTopology(
        tp_rank=self.tp_rank,
        engine_id=self.engine_id,
        remote_tp_size=self._tp_size,  # shared state
        remote_block_size=self._block_size,  # shared state
        is_mla=self.use_mla,
        total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
        attn_backend=backend,
    )
    self._use_pallas = self.kv_topo._use_pallas

    # NOTE(review): these are assigned last, so they do not exist if any
    # earlier step raised.
    self.zmq_ctx = zmq.Context()
    self.async_zmq_ctx = zmq.asyncio.Context()
    self._encoder = msgspec.msgpack.Encoder()
    self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
def __del__(self):
    """Release background resources when the worker is finalised."""
    self.shutdown()
def shutdown(self):
    """Cleanup background threads on destruction.

    Terminates both ZMQ contexts, shuts down the sender thread pool and joins
    the sender thread (sender side), then stops the receiver event loop and
    joins its thread (receiver side).

    Every resource is looked up defensively because ``__del__`` calls this
    method even when ``__init__`` raised part-way through. In that case some
    attributes (``zmq_ctx`` is assigned last) do not exist, and an
    ``AttributeError`` here would also skip stopping a receiver loop that
    was already started.
    """
    zmq_ctx = getattr(self, "zmq_ctx", None)
    if zmq_ctx is not None:
        zmq_ctx.term()
    async_zmq_ctx = getattr(self, "async_zmq_ctx", None)
    if async_zmq_ctx is not None:
        async_zmq_ctx.term()

    # The sender attributes only exist when kv_role != "kv_consumer".
    sender_executor = getattr(self, "_sender_executor", None)
    if sender_executor is not None:
        sender_executor.shutdown(wait=False)
    sender_thread = getattr(self, "_mooncake_sender_t", None)
    if sender_thread is not None:
        sender_thread.join()

    # The receiver attributes only exist when kv_role != "kv_producer".
    receiver_loop = getattr(self, "receiver_loop", None)
    if receiver_loop is not None and receiver_loop.is_running():
        receiver_loop.call_soon_threadsafe(receiver_loop.stop)
        receiver_thread = getattr(self, "_mooncake_receiver_t", None)
        if receiver_thread is not None:
            receiver_thread.join()
def
_receiver_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
):
asyncio
.
set_event_loop
(
loop
)
loop
.
run_forever
()
def _mooncake_sender(
    self, ready_event: threading.Event, base_port: int, tp_rank: int
):
    """
    Background thread that listens for Mooncake requests, dispatches them
    to a thread pool, and sends acknowledgments upon completion.

    Args:
        ready_event: set once both sockets are created and registered with
            the poller.
        base_port: base side-channel port; the listener uses
            ``base_port + tp_rank``.
        tp_rank: rank offset added to ``base_port``.

    A request whose frame count is not the expected three is dropped with a
    warning instead of terminating the listener.
    """
    frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
    frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
    logger.debug("Mooncake sender starting listening on path: %s", frontend_path)

    # In-process channel on which pool workers report their status.
    backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
    backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)

    poller = zmq.Poller()
    poller.register(frontend, zmq.POLLIN)
    poller.register(backend, zmq.POLLIN)

    ready_event.set()

    try:
        while True:
            sockets = dict(poller.poll())

            if frontend in sockets:
                frames = frontend.recv_multipart()
                if len(frames) != 3:
                    # Malformed peer message: unpacking it would raise and
                    # permanently stop this listener.
                    logger.warning(
                        "Mooncake sender dropping malformed request with "
                        "%d frames (expected 3)",
                        len(frames),
                    )
                else:
                    identity, _, metadata_bytes = frames
                    self._sender_executor.submit(
                        self._sender_worker,
                        identity,
                        metadata_bytes,
                        backend_path,
                    )

            if backend in sockets:
                identity, status = backend.recv_multipart()
                # Relay the worker's status to the originating peer.
                frontend.send_multipart((identity, b"", status))

    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
    except Exception as e:
        # logger.exception keeps the traceback of the fatal error.
        logger.exception(
            "Error in Mooncake sender thread: %s. Exiting thread.", str(e)
        )
    finally:
        frontend.close()
        backend.close()
def _sender_worker(
    self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
):
    """Pool task: decode one request, perform the transfer and push
    ``(identity, status)`` to ``worker_channel_path``.

    The status is ``TRANS_DONE`` only when decoding and sending both
    succeeded; otherwise ``TRANS_ERROR`` is reported.
    """
    # Stays at the error status unless the whole transfer succeeds.
    outcome = TRANS_ERROR

    try:
        self.send_kv_to_decode(self._decoder.decode(metadata_bytes))
    except Exception as e:
        logger.error("Error processing Mooncake handshake: %s", e)
    else:
        outcome = TRANS_DONE
    finally:
        reply_socket = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
        try:
            reply_socket.send_multipart((identity, outcome))
        except zmq.ZMQError as e:
            logger.warning(
                "Internal error, maybe the server is shutting down. Error: %s",
                e,
            )
        finally:
            reply_socket.close()
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
    """Send the KV blocks of every request in ``meta`` and mark them finished.

    Args:
        meta: decoded peer message listing the request ids, the peer's block
            ids and the peer's cache base addresses.

    Raises:
        RuntimeError: if any request id in ``meta`` is not pending in
            ``reqs_need_send``. Previously this case returned silently, so
            the caller reported a successful transfer that never happened.
        Exception: whatever ``_send_blocks`` raises. In that case the
            previous expiry of each request is restored so the timeout
            handling can still reclaim it.
    """
    send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
    previous_expiry: list[float] = []
    with self.reqs_need_send.lock:
        pending = self.reqs_need_send.reqs
        # Validate the whole batch before touching any entry, so a failure
        # does not leave earlier requests with their expiry disabled.
        missing = [req_id for req_id in meta.request_ids if req_id not in pending]
        if missing:
            for req_id in missing:
                logger.warning("Request %s not found in reqs_need_send", req_id)
            raise RuntimeError(
                f"Cannot send KV cache, unknown request ids: {missing}"
            )
        for req_id in meta.request_ids:
            send_meta = pending[req_id]
            previous_expiry.append(send_meta.expire_time)
            # Mark it as not expired. We will send it now.
            send_meta.expire_time = float("inf")
            send_reqs.append((req_id, send_meta))

    try:
        self._send_blocks(send_reqs, meta)
    except Exception:
        # The entries stay in reqs_need_send; give them their deadline back
        # so they are not kept forever.
        with self.reqs_need_send.lock:
            for (_, send_meta), expiry in zip(send_reqs, previous_expiry):
                send_meta.expire_time = expiry
        raise

    with self.reqs_need_send.lock:
        for req_id in meta.request_ids:
            # pop() tolerates an id that appears twice in the message.
            self.reqs_need_send.reqs.pop(req_id, None)

    with self.finished_sending_reqs.lock:
        self.finished_sending_reqs.set.update(meta.request_ids)
def _send_blocks(
    self,
    send_reqs: list[tuple[ReqId, SendBlockMeta]],
    agent_meta: MooncakeAgentMetadata,
):
    """Write the local KV blocks of ``send_reqs`` into the peer's cache.

    For every request the method waits until its blocks are ready, pairs the
    trailing local blocks with the peer's block ids, groups them via
    ``group_concurrent_contiguous`` and builds one (source, destination,
    length) triple per group and per registered cache region. All triples
    are then sent in a single batch write.

    Args:
        send_reqs: ``(request id, send metadata)`` pairs, in the same order
            as ``agent_meta.block_ids``.
        agent_meta: peer message with the target session, the peer's cache
            base addresses and its block ids.

    Raises:
        RuntimeError: if the batch write returns a non-zero code.
    """
    src_ptrs = []
    dst_ptrs = []
    lengths = []
    local_base_addr = self.kv_caches_base_addr
    remote_base_addr = agent_meta.kv_caches_base_addr
    block_len = self.block_len
    remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"

    assert len(send_reqs) == len(agent_meta.block_ids)
    for (req_id, send_meta), remote_block_ids in zip(
        send_reqs, agent_meta.block_ids
    ):
        send_meta.ready.wait()

        num_remote_blocks = len(remote_block_ids)
        if num_remote_blocks == 0:
            continue

        local_block_ids = send_meta.local_block_ids
        # Partial prefix cache hit: just read uncomputed blocks.
        num_local_blocks = len(local_block_ids)
        assert num_local_blocks >= num_remote_blocks
        if num_local_blocks > num_remote_blocks:
            local_block_ids = local_block_ids[-num_remote_blocks:]

        # Group by indices
        group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
            local_block_ids, remote_block_ids
        )

        for local_layer_addr, remote_layer_addr in zip(
            local_base_addr, remote_base_addr
        ):
            for group_local_block_id, group_remote_block_id in zip(
                group_local_block_ids, group_remote_block_ids
            ):
                # Each group becomes one transfer starting at its first
                # block and spanning the whole group.
                src_ptrs.append(
                    local_layer_addr + group_local_block_id[0] * block_len
                )
                dst_ptrs.append(
                    remote_layer_addr + group_remote_block_id[0] * block_len
                )
                lengths.append(block_len * len(group_local_block_id))

        logger.debug(
            "Sending kv_caches for request %s (%d blocks) to %s",
            req_id,
            num_remote_blocks,
            remote_session,
        )

    if not src_ptrs:
        # Every request asked for zero blocks, so there is nothing to write.
        # Skip the engine call instead of issuing an empty batch.
        logger.debug("No kv_caches blocks to send to %s", remote_session)
        return

    start_time = time.perf_counter()
    ret_value = self.engine.batch_transfer_sync_write(
        remote_session, src_ptrs, dst_ptrs, lengths
    )
    if ret_value != 0:
        raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")

    logger.debug(
        "Sending to %s done, took %s",
        remote_session,
        time.perf_counter() - start_time,
    )
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data in mooncake.

    Collects the base address and byte size of every distinct cache tensor
    (K and V separately when ``kv_topo.split_k_and_v`` is set), registers
    them with the transfer engine, derives ``num_blocks`` / ``block_len``
    and, unless this worker is a pure consumer, starts the sender thread
    and waits until its listener is ready.

    Args:
        kv_caches: cache tensors keyed by layer name.

    Raises:
        RuntimeError: if memory registration fails, or if the sender thread
            exits before signalling readiness. Previously the latter case
            blocked forever on the ready event.
    """
    logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)

    kv_data_ptrs = []
    kv_data_lens = []
    # The list keeps registration order; the set gives O(1) duplicate checks.
    seen_base_addresses = []
    seen_lookup = set()

    split_k_and_v = self.kv_topo.split_k_and_v
    tensor_size_bytes = None
    for layer_name, cache_or_caches in kv_caches.items():
        logger.debug(
            "registering layer %s with shape %s", layer_name, cache_or_caches.shape
        )
        cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

        for cache in cache_list:
            base_addr = cache.data_ptr()
            if base_addr in seen_lookup:
                # Several layers may share the same underlying tensor.
                continue

            seen_lookup.add(base_addr)
            seen_base_addresses.append(base_addr)
            curr_tensor_size_bytes = cache.nbytes

            if tensor_size_bytes is None:
                tensor_size_bytes = curr_tensor_size_bytes
                self.num_blocks = cache.shape[0]

            assert tensor_size_bytes == curr_tensor_size_bytes, (
                "All kv cache tensors must have the same size"
            )
            kernel_block_size = cache.shape[-2 if self.use_mla else -3]
            assert self.block_size == kernel_block_size
            kv_data_ptrs.append(base_addr)
            kv_data_lens.append(tensor_size_bytes)

    self.kv_caches_base_addr = seen_base_addresses

    ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
    if ret_value != 0:
        raise RuntimeError("Mooncake batch memory registration failed.")

    assert tensor_size_bytes is not None
    assert self.num_blocks != 0
    assert tensor_size_bytes % self.num_blocks == 0
    self.block_len = tensor_size_bytes // self.num_blocks
    self.device_kv_caches = kv_caches
    logger.debug(
        "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
    )

    # No need to launch server for D node.
    if self.kv_role == "kv_consumer":
        return

    ready_event = threading.Event()
    self._mooncake_sender_t = threading.Thread(
        target=self._mooncake_sender,
        args=(ready_event, self.side_channel_port, self.tp_rank),
        daemon=True,
        name="mooncake_sender",
    )
    self._mooncake_sender_t.start()
    # Wait for listener ZMQ socket to be ready. Poll with a timeout so that a
    # sender thread that died before signalling (for example because its
    # socket could not be created) is reported instead of hanging forever.
    while not ready_event.wait(timeout=1.0):
        if not self._mooncake_sender_t.is_alive():
            raise RuntimeError(
                "Mooncake sender thread exited before its listener was ready."
            )
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
    """Take the accumulated finished-receiving request ids and reset them.

    While holding the tracker's async lock, the current set is swapped for
    a fresh empty one, and the previous set is returned to the caller.

    Returns:
        The set of request ids that was stored before the swap.
    """
    tracker = self.finished_recving_reqs
    async with tracker.lock:
        drained, tracker.set = tracker.set, set()
    return drained
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
    """
    Collect the requests that finished sending or receiving on this worker.

    For every role except ``"kv_producer"`` the finished-receiving set is
    drained on ``self.receiver_loop``; for every role except
    ``"kv_consumer"`` the finished-sending set is swapped out under its
    lock. Requests still waiting in ``self.reqs_need_send`` whose
    ``expire_time`` has passed are removed and reported as finished
    sending.

    Returns:
        ``(finished_sending, finished_recving)``; each element is ``None``
        when the corresponding set is empty.
    """
    # Kick off the drain on the receiver loop first; its result is
    # collected only after the sending side has been handled.
    recv_future = None
    if self.kv_role != "kv_producer":
        recv_future = asyncio.run_coroutine_threadsafe(
            self.fetch_finished_recving_reqs(), self.receiver_loop
        )

    if self.kv_role == "kv_consumer":
        finished_sending_reqs = set()
    else:
        with self.finished_sending_reqs.lock:
            finished_sending_reqs = self.finished_sending_reqs.set
            self.finished_sending_reqs.set = set()

    if recv_future:
        finished_recving_reqs = recv_future.result()
    else:
        finished_recving_reqs = set()

    if finished_sending_reqs or finished_recving_reqs:
        logger.debug(
            "Rank %s, get_finished: %s requests done sending "
            "and %s requests done recving",
            self.tp_rank,
            len(finished_sending_reqs),
            len(finished_recving_reqs),
        )

    # Handle timeout to avoid stranding blocks on remote.
    now = time.perf_counter()
    with self.reqs_need_send.lock:
        pending = self.reqs_need_send.reqs
        # Collect the ids first so the dict is not mutated while iterating.
        expired_reqs = [
            req_id
            for req_id, send_meta in pending.items()
            if send_meta.expire_time < now
        ]
        for req_id in expired_reqs:
            logger.warning(
                "Request %s timed out after %d seconds without "
                "being sent. Freeing its blocks on the producer side.",
                req_id,
                envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
            )
            del pending[req_id]

    # Updating with an empty list is a no-op, so no guard is required.
    finished_sending_reqs.update(expired_reqs)

    return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
    """Request the KV blocks for a batch of requests from one remote endpoint.

    Builds a ``MooncakeAgentMetadata`` message describing this worker and
    the destination blocks, sends it over a ZMQ REQ socket connected to
    ``path`` and waits for the reply. The requests are recorded in
    ``self.finished_recving_reqs`` only when the reply equals
    ``TRANS_DONE``.

    Any failure (unexpected reply, terminated ZMQ context, or any other
    exception) is logged and the coroutine returns without marking the
    requests as finished.

    Args:
        path: ZMQ address of the remote side to query.
        req_blocks: ``(request_id, local_block_ids)`` pairs to pull; must
            be non-empty because it is unzipped into two lists.
    """
    req_ids, block_ids = map(list, zip(*req_blocks))
    metadata = MooncakeAgentMetadata(
        remote_hostname=self.hostname,
        remote_port=self.rpc_port,
        request_ids=req_ids,
        kv_caches_base_addr=self.kv_caches_base_addr,
        block_ids=block_ids,
    )

    encoded_data = self._encoder.encode(metadata)
    logger.debug(
        "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
    )
    logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)

    # Send query for the request.
    sock: zmq.asyncio.Socket = make_zmq_socket(
        self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
    )
    try:
        # Configure the socket inside the try so that a failure here still
        # reaches the ``finally`` clause and closes the socket.
        sock.setsockopt(zmq.RCVTIMEO, 60000)
        await sock.send(encoded_data)
        ret_msg = await sock.recv()
        if ret_msg != TRANS_DONE:
            logger.error(
                "Error happens during tranfering kvcache for %s, see logs in prefiller.",  # noqa: E501
                req_ids,
            )
            return
    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
        # No reply was received, so the transfer cannot be reported as done.
        return
    except Exception as e:
        logger.error(
            "MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e
        )
        return
    finally:
        sock.close()

    async with self.finished_recving_reqs.lock:
        self.finished_recving_reqs.set.update(req_ids)

    logger.debug("pulling kv_caches for %s finished", req_ids)
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
    """Group the requests to receive by the remote endpoint that serves them.

    For each entry of ``metadata.reqs_to_recv`` the endpoint is the TCP
    path built from the request's remote host and its remote port offset
    by this worker's ``tp_rank``.

    Args:
        metadata: Connector metadata whose ``reqs_to_recv`` maps request
            ids to their receive metadata.

    Returns:
        A ``defaultdict(list)`` mapping each ZMQ path to the list of
        ``(request_id, local_block_ids)`` pairs to pull from it.
    """
    pulls_by_path = defaultdict(list)
    for req_id, meta in metadata.reqs_to_recv.items():
        local_block_ids = meta.local_block_ids
        logger.debug(
            "start_load_kv for request %s from remote engine. "
            "Num local_block_ids: %s.",
            req_id,
            len(local_block_ids),
        )
        remote_port = meta.remote_port + self.tp_rank
        path = make_zmq_path("tcp", meta.remote_host, remote_port)
        pulls_by_path[path].append((req_id, local_block_ids))
    return pulls_by_path
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
    """Start the receive and send work described by ``metadata``.

    Receiving (any role except ``"kv_producer"``): requests are grouped by
    remote path and one ``receive_kv`` coroutine per path is scheduled on
    ``self.receiver_loop``.

    Sending (any role except ``"kv_consumer"``): under the
    ``reqs_need_send`` lock, each request in ``metadata.reqs_to_send`` is
    either registered with an empty block list (no block ids yet) or, when
    block ids are present, has its existing entry filled in, its ``ready``
    event set and its expiry time refreshed.

    Args:
        metadata: Connector metadata carrying ``reqs_to_recv`` and
            ``reqs_to_send``.
    """
    role = self.kv_role

    if role != "kv_producer":
        for path, req_blocks in self.group_kv_pull(metadata).items():
            asyncio.run_coroutine_threadsafe(
                self.receive_kv(path, req_blocks), self.receiver_loop
            )

    if role == "kv_consumer":
        return

    with self.reqs_need_send.lock:
        pending = self.reqs_need_send.reqs
        for req_id, block_ids in metadata.reqs_to_send.items():
            if not block_ids:
                # From update_state_after_alloc(),
                # but not reach request_finished() yet
                pending[req_id] = SendBlockMeta(
                    local_block_ids=[], ready=threading.Event()
                )
                continue

            # Already gone through request_finished()
            send_meta = pending[req_id]
            send_meta.local_block_ids = block_ids
            send_meta.ready.set()
            send_meta.expire_time = (
                time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
            )
def group_concurrent_contiguous(
    src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
    """Split two parallel index lists into runs that are contiguous in both.

    A new group starts at every position where either the source or the
    destination index is not exactly one greater than its predecessor.
    Both lists are cut at the same positions, so group ``i`` of the source
    lines up with group ``i`` of the destination.

    Args:
        src_indices: Source indices.
        dst_indices: Destination indices, parallel to ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)`` as lists of plain Python lists; two
        empty lists when ``src_indices`` is empty.
    """
    if len(src_indices) == 0:
        return [], []

    # True wherever the step to the next element is not +1.
    src_jump = np.diff(src_indices) != 1
    dst_jump = np.diff(dst_indices) != 1
    # diff position i describes the step i -> i + 1, so cut before i + 1.
    cut_points = np.flatnonzero(src_jump | dst_jump) + 1

    return (
        [chunk.tolist() for chunk in np.split(src_indices, cut_points)],
        [chunk.tolist() for chunk in np.split(dst_indices, cut_points)],
    )
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
    """Compute the Mooncake side-channel port for this configuration.

    The port is ``VLLM_MOONCAKE_BOOTSTRAP_PORT`` plus the data-parallel
    rank multiplied by the tensor-parallel size.

    Args:
        vllm_config: Configuration whose ``parallel_config`` supplies the
            data-parallel rank and tensor-parallel size.

    Returns:
        The computed port number.
    """
    # This logic is now centralized
    base_port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
    parallel_config = vllm_config.parallel_config
    rank_offset = (
        parallel_config.data_parallel_rank * parallel_config.tensor_parallel_size
    )
    return base_port + rank_offset
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
8d75f22e
...
@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
per_engine_labelvalues
,
per_engine_labelvalues
,
prom_metrics
,
prom_metrics
,
)
)
def reset_cache(self) -> bool:
    """Reset the cache of every wrapped connector.

    Every connector in ``self._connectors`` is always called, even after
    one of them has reported failure. A connector counts as failed only
    when its ``reset_cache()`` returns exactly ``False``; any other return
    value (including ``None``) counts as success.

    Returns:
        ``True`` when no connector returned ``False`` (also when there are
        no connectors), ``False`` otherwise.
    """
    succeeded = True
    for connector in self._connectors:
        if connector.reset_cache() is False:
            succeeded = False
    return succeeded
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
8d75f22e
...
@@ -20,10 +20,10 @@ import torch
...
@@ -20,10 +20,10 @@ import torch
import
zmq
import
zmq
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorBase_V1
,
...
@@ -55,10 +55,26 @@ if TYPE_CHECKING:
...
@@ -55,10 +55,26 @@ if TYPE_CHECKING:
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
Transfer
=
tuple
[
int
,
float
]
# (xfer_handle, start_time)
Transfer
Handle
=
int
EngineId
=
str
EngineId
=
str
ReqId
=
str
ReqId
=
str
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
NIXL_CONNECTOR_VERSION
:
int
=
2
GET_META_MSG
=
b
"get_meta_msg"
GET_META_MSG
=
b
"get_meta_msg"
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -97,18 +113,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
...
@@ -97,18 +113,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@
dataclass
@
dataclass
class
NixlAgentMetadata
(
KVConnectorHandshakeMetadata
)
:
class
NixlAgentMetadata
:
engine_id
:
str
engine_id
:
str
agent_metadata
:
bytes
agent_metadata
:
bytes
kv_caches_base_addr
:
list
[
int
]
kv_caches_base_addr
:
list
[
int
]
device_id
:
int
device_id
:
int
num_blocks
:
int
num_blocks
:
int
block_lens
:
list
[
int
]
block_lens
:
list
[
int
]
attn_backend_name
:
str
kv_cache_layout
:
str
kv_cache_layout
:
str
block_size
:
int
block_size
:
int
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    """
    Envelope for the NIXL handshake that is sent over the wire.

    The envelope allows the receiver to decode in two phases:
    1. Decode this payload to read ``compatibility_hash``.
    2. Compare it with the locally computed hash.
    3. Decode ``agent_metadata_bytes`` only when the hashes match.

    Checking the hash first lets an incompatible NixlAgentMetadata schema
    fail with a clear error instead of a decoder error.
    """

    # Compatibility hash computed by the sending side.
    compatibility_hash: str
    # NixlAgentMetadata encoded
    agent_metadata_bytes: bytes
def compute_nixl_compatibility_hash(
    vllm_config: VllmConfig, attn_backend_name: str
) -> str:
    """
    Compute the compatibility hash used during the NIXL handshake.

    The hash covers exactly these factors:
    - vLLM version and ``NIXL_CONNECTOR_VERSION``
    - Model: name, dtype, total KV heads, head size, total hidden layers
    - Attention backend name
    - KV cache dtype

    Note: the original documentation states that tensor_parallel_size,
    block_size and kv_cache_layout are validated at runtime in
    _validate_remote_agent_handshake and are deliberately left out of the
    hash to support heterogeneous deployments.

    Args:
        vllm_config: Configuration supplying ``model_config`` and
            ``cache_config``.
        attn_backend_name: Name of the attention backend in use.

    Returns:
        The digest string produced by ``hash_factors`` (described by the
        original documentation as a SHA-256 hex digest).
    """
    from vllm import __version__ as vllm_version
    from vllm.config.utils import hash_factors

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config

    # Gather the values up front, in the same order they enter the dict.
    model_name = model_config.model
    model_dtype = str(model_config.dtype)
    num_kv_heads = model_config.get_total_num_kv_heads()
    head_size = model_config.get_head_size()
    num_hidden_layers = model_config.get_total_num_hidden_layers()
    cache_dtype = str(cache_config.cache_dtype)

    factors = {
        # Version compatibility
        "vllm_version": vllm_version,
        "nixl_connector_version": NIXL_CONNECTOR_VERSION,
        # Model architecture - affects KV cache shape
        "model": model_name,
        "dtype": model_dtype,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "num_hidden_layers": num_hidden_layers,
        # Attention backend and KV cache dtype affect memory layout
        "attn_backend_name": attn_backend_name,
        "cache_dtype": cache_dtype,
    }

    compat_hash = hash_factors(factors)
    logger.debug(
        "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
        "cache_dtype=%s, attn_backend=%s)",
        compat_hash,
        model_name,
        model_dtype,
        num_kv_heads,
        cache_dtype,
        attn_backend_name,
    )
    return compat_hash
@
dataclass
@
dataclass
class
ReqMeta
:
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
list
[
int
]
...
@@ -118,6 +211,7 @@ class ReqMeta:
...
@@ -118,6 +211,7 @@ class ReqMeta:
remote_host
:
str
remote_host
:
str
remote_port
:
int
remote_port
:
int
remote_engine_id
:
str
remote_engine_id
:
str
remote_request_id
:
str
tp_size
:
int
tp_size
:
int
...
@@ -144,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -144,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_physical_block_ids
=
local_block_ids
,
local_physical_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_request_id
=
kv_transfer_params
[
"remote_request_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
# P workers don't need to receive tp_size from proxy here.
# P workers don't need to receive tp_size from proxy here.
...
@@ -396,14 +491,14 @@ class NixlConnectorScheduler:
...
@@ -396,14 +491,14 @@ class NixlConnectorScheduler:
encoded_data
:
dict
[
int
,
bytes
]
=
{}
encoded_data
:
dict
[
int
,
bytes
]
=
{}
encoder
=
msgspec
.
msgpack
.
Encoder
()
encoder
=
msgspec
.
msgpack
.
Encoder
()
for
tp_rank
,
rank_metadata
in
metadata
.
items
():
for
tp_rank
,
rank_metadata
in
metadata
.
items
():
if
not
isinstance
(
rank_metadata
,
Nixl
AgentMetadata
):
if
not
isinstance
(
rank_metadata
,
Nixl
HandshakePayload
):
raise
ValueError
(
raise
ValueError
(
"NixlConnectorScheduler expects Nixl
AgentMetadata
for "
"NixlConnectorScheduler expects Nixl
HandshakePayload
for "
"handshake metadata."
"handshake metadata."
)
)
encoded_data
[
tp_rank
]
=
encoder
.
encode
(
rank_metadata
)
encoded_data
[
tp_rank
]
=
encoder
.
encode
(
rank_metadata
)
logger
.
debug
(
logger
.
debug
(
"Tp rank %d: encoded Nixl
AgentMetadata
size: %s bytes"
,
"Tp rank %d: encoded Nixl
HandshakePayload
size: %s bytes"
,
tp_rank
,
tp_rank
,
str
(
len
(
encoded_data
[
tp_rank
])),
str
(
len
(
encoded_data
[
tp_rank
])),
)
)
...
@@ -530,7 +625,12 @@ class NixlConnectorScheduler:
...
@@ -530,7 +625,12 @@ class NixlConnectorScheduler:
if
params
.
get
(
"remote_block_ids"
):
if
params
.
get
(
"remote_block_ids"
):
if
all
(
if
all
(
p
in
params
p
in
params
for
p
in
(
"remote_engine_id"
,
"remote_host"
,
"remote_port"
)
for
p
in
(
"remote_engine_id"
,
"remote_request_id"
,
"remote_host"
,
"remote_port"
,
)
):
):
# If remote_blocks and num_external_tokens = 0, we have
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# a full prefix cache hit on the D worker. We need to call
...
@@ -659,6 +759,7 @@ class NixlConnectorScheduler:
...
@@ -659,6 +759,7 @@ class NixlConnectorScheduler:
do_remote_decode
=
False
,
do_remote_decode
=
False
,
remote_block_ids
=
block_ids
,
remote_block_ids
=
block_ids
,
remote_engine_id
=
self
.
engine_id
,
remote_engine_id
=
self
.
engine_id
,
remote_request_id
=
request
.
request_id
,
remote_host
=
self
.
side_channel_host
,
remote_host
=
self
.
side_channel_host
,
remote_port
=
self
.
side_channel_port
,
remote_port
=
self
.
side_channel_port
,
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
,
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
,
...
@@ -668,128 +769,6 @@ class NixlConnectorScheduler:
...
@@ -668,128 +769,6 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
"""Implementation of Worker side methods"""
@dataclass
class TpKVTopology:
    """
    Tensor-parallel and KV-layout facts used to map local TP workers onto
    remote TP workers.

    ``remote_tp_size`` and ``remote_block_size`` are keyed by engine id and
    also hold the local engine's own values under ``engine_id``.
    """

    tp_rank: int
    remote_tp_size: dict[EngineId, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: EngineId
    remote_block_size: dict[EngineId, int]

    def __post_init__(self):
        """Derive the layout flags from the attention backend."""
        # Ask the backend for a cache shape with num_blocks mocked to 1 to
        # find out whether the leading dimension is K/V or num_blocks. This
        # decides how memory regions are registered.
        probe_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backend caches have 5 dims [2, num_blocks, H,N,D]; a
        # leading 1 therefore means the blocks dimension comes first.
        has_five_dims = len(probe_shape) == 5
        self._is_kv_layout_blocks_first = has_five_dims and probe_shape[0] == 1

        backend_enum = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = backend_enum == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        """Flag computed in ``__post_init__`` from the probed cache shape."""
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        """Whether to register regions for K and V separately (when present)."""
        return (
            not self.is_mla
            and not self._use_pallas
            and not self.is_kv_layout_blocks_first
        )

    @property
    def tp_size(self) -> int:
        """TP size recorded for the local engine."""
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        """Block size recorded for the local engine."""
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Return the local TP size divided by ``remote_tp_size``.

        This is the number of local TP workers per remote TP worker: local
        workers read from the same remote TP worker in groups of this size.
        The local size must be a multiple of the remote size.
        """
        ratio, remainder = divmod(self.tp_size, remote_tp_size)
        assert remainder == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return ratio

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> float:
        """
        Return the local block size divided by ``remote_block_size``.

        The local block size must be a multiple of the remote block size.
        """
        ratio, remainder = divmod(self.block_size, remote_block_size)
        assert remainder == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return ratio

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        """``tp_ratio`` for the TP size recorded under ``remote_engine_id``."""
        return self.tp_ratio(self.remote_tp_size[remote_engine_id])

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> float:
        """``block_size_ratio`` for the block size recorded under the id."""
        return self.block_size_ratio(self.remote_block_size[remote_engine_id])

    def is_kv_replicated(self, engine_id: EngineId) -> bool:
        """
        Whether the KV cache is treated as replicated across TP workers.

        True when the engine's TP size floor-divided by the total number of
        KV heads is at least 1.
        """
        engine_tp_size = self.remote_tp_size[engine_id]
        workers_per_head = engine_tp_size // self.total_num_kv_heads
        return workers_per_head >= 1

    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
        """True for MLA, or when ``is_kv_replicated`` holds for the engine."""
        # MLA is always replicated as the hidden dim can't be split.
        if self.is_mla:
            return self.is_mla
        return self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Return the remote TP rank (on P) that this local TP rank (on D)
        reads from: the local rank floor-divided by the TP ratio.
        """
        return self.tp_rank // self.tp_ratio(remote_tp_size)

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        """``get_target_remote_rank`` for the TP size recorded under the id."""
        return self.get_target_remote_rank(self.remote_tp_size[remote_engine_id])
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
if
NixlWrapper
is
None
:
if
NixlWrapper
is
None
:
logger
.
error
(
"NIXL is not available"
)
logger
.
error
(
"NIXL is not available"
)
...
@@ -904,7 +883,7 @@ class NixlConnectorWorker:
...
@@ -904,7 +883,7 @@ class NixlConnectorWorker:
# In progress transfers.
# In progress transfers.
# [req_id -> list[handle]]
# [req_id -> list[handle]]
self
.
_recving_metadata
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
_recving_metadata
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
_recving_transfers
=
defaultdict
[
ReqId
,
list
[
Transfer
]](
list
)
self
.
_recving_transfers
=
defaultdict
[
ReqId
,
list
[
Transfer
Handle
]](
list
)
# Track the expiration time of requests that are waiting to be sent.
# Track the expiration time of requests that are waiting to be sent.
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
# Set of requests that have been part of a batch, regardless of status.
# Set of requests that have been part of a batch, regardless of status.
...
@@ -916,7 +895,7 @@ class NixlConnectorWorker:
...
@@ -916,7 +895,7 @@ class NixlConnectorWorker:
self
.
_failed_recv_reqs
:
set
[
ReqId
]
=
set
()
self
.
_failed_recv_reqs
:
set
[
ReqId
]
=
set
()
# Handshake metadata of this worker for NIXL transfers.
# Handshake metadata of this worker for NIXL transfers.
self
.
xfer_handshake_metadata
:
Nixl
AgentMetadata
|
None
=
None
self
.
xfer_handshake_metadata
:
Nixl
HandshakePayload
|
None
=
None
# Background thread for initializing new NIXL handshakes.
# Background thread for initializing new NIXL handshakes.
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
...
@@ -951,6 +930,13 @@ class NixlConnectorWorker:
...
@@ -951,6 +930,13 @@ class NixlConnectorWorker:
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
self
.
vllm_config
,
self
.
backend_name
)
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# With heterogeneous TP, P must wait for all assigned D TP workers to
...
@@ -958,7 +944,7 @@ class NixlConnectorWorker:
...
@@ -958,7 +944,7 @@ class NixlConnectorWorker:
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
xfer_stats
=
NixlKVConnectorStats
()
self
.
xfer_stats
=
NixlKVConnectorStats
()
self
.
kv_topo
=
self
.
TpKVTopology
(
self
.
kv_topo
=
TpKVTopology
(
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
remote_tp_size
=
self
.
_tp_size
,
# shared state
remote_tp_size
=
self
.
_tp_size
,
# shared state
...
@@ -999,14 +985,58 @@ class NixlConnectorWorker:
...
@@ -999,14 +985,58 @@ class NixlConnectorWorker:
# Set receive timeout to 5 seconds to avoid hanging on dead server
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
sock
.
send
(
msg
)
metadata_bytes
=
sock
.
recv
()
handshake_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible configurations. "
f
"This may be due to: different vLLM versions, models, dtypes, "
f
"KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
logger
.
info
(
"NIXL compatibility check passed (hash: %s)"
,
handshake_payload
.
compatibility_hash
,
)
# Decode agent metadata
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
.
agent_metadata_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -1180,14 +1210,11 @@ class NixlConnectorWorker:
...
@@ -1180,14 +1210,11 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used.
# Enable different block lengths for different layers when MLA is used.
self
.
block_len_per_layer
=
list
[
int
]()
self
.
block_len_per_layer
=
list
[
int
]()
self
.
slot_size_per_layer
=
list
[
int
]()
# HD bytes in kv terms
self
.
slot_size_per_layer
=
list
[
int
]()
# HD bytes in kv terms
self
.
device_id
=
self
.
tp_rank
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
for
cache
in
cache_list
:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
base_addr
=
cache
.
data_ptr
()
if
not
self
.
use_host_buffer
and
current_platform
.
is_cuda_alike
():
self
.
device_id
=
cache
.
device
.
index
if
base_addr
in
seen_base_addresses
:
if
base_addr
in
seen_base_addresses
:
continue
continue
...
@@ -1230,8 +1257,7 @@ class NixlConnectorWorker:
...
@@ -1230,8 +1257,7 @@ class NixlConnectorWorker:
"All kv cache tensors must have the same size"
"All kv cache tensors must have the same size"
)
)
# Need to make sure the device ID is non-negative for NIXL,
# Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
# Torch uses -1 to indicate CPU tensors.
# memory type.
self
.
device_id
=
max
(
cache
.
get_device
(),
0
)
self
.
device_id
=
max
(
cache
.
get_device
(),
0
)
caches_data
.
append
(
caches_data
.
append
(
(
base_addr
,
curr_tensor_size_bytes
,
self
.
device_id
,
""
)
(
base_addr
,
curr_tensor_size_bytes
,
self
.
device_id
,
""
)
...
@@ -1297,19 +1323,24 @@ class NixlConnectorWorker:
...
@@ -1297,19 +1323,24 @@ class NixlConnectorWorker:
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
# After KV Caches registered, listen for new connections.
self
.
xfer_handshake
_metadata
=
NixlAgentMetadata
(
agent
_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
device_id
=
self
.
device_id
,
device_id
=
self
.
device_id
,
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
block_lens
=
self
.
block_len_per_layer
,
block_lens
=
self
.
block_len_per_layer
,
attn_backend_name
=
self
.
backend_name
,
kv_cache_layout
=
self
.
kv_cache_layout
kv_cache_layout
=
self
.
kv_cache_layout
if
not
self
.
use_host_buffer
if
not
self
.
use_host_buffer
else
self
.
host_buffer_kv_cache_layout
,
else
self
.
host_buffer_kv_cache_layout
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
)
)
# Wrap metadata in payload with hash for defensive decoding
encoder
=
msgspec
.
msgpack
.
Encoder
()
self
.
xfer_handshake_metadata
=
NixlHandshakePayload
(
compatibility_hash
=
self
.
compat_hash
,
agent_metadata_bytes
=
encoder
.
encode
(
agent_metadata
),
)
def
register_local_xfer_handler
(
def
register_local_xfer_handler
(
self
,
self
,
...
@@ -1524,8 +1555,6 @@ class NixlConnectorWorker:
...
@@ -1524,8 +1555,6 @@ class NixlConnectorWorker:
remote_engine_id
=
nixl_agent_meta
.
engine_id
remote_engine_id
=
nixl_agent_meta
.
engine_id
assert
self
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
assert
self
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
...
@@ -1818,9 +1847,7 @@ class NixlConnectorWorker:
...
@@ -1818,9 +1847,7 @@ class NixlConnectorWorker:
self
.
_reqs_to_send
.
pop
(
req_id
,
None
)
self
.
_reqs_to_send
.
pop
(
req_id
,
None
)
return
notified_req_ids
return
notified_req_ids
def
_pop_done_transfers
(
def
_pop_done_transfers
(
self
,
transfers
:
dict
[
str
,
list
[
int
]])
->
set
[
str
]:
self
,
transfers
:
dict
[
str
,
list
[
tuple
[
int
,
float
]]]
)
->
set
[
str
]:
"""
"""
Pop completed xfers by checking for DONE state.
Pop completed xfers by checking for DONE state.
Args:
Args:
...
@@ -1831,7 +1858,7 @@ class NixlConnectorWorker:
...
@@ -1831,7 +1858,7 @@ class NixlConnectorWorker:
done_req_ids
:
set
[
str
]
=
set
()
done_req_ids
:
set
[
str
]
=
set
()
for
req_id
,
handles
in
list
(
transfers
.
items
()):
for
req_id
,
handles
in
list
(
transfers
.
items
()):
in_progress
=
False
in_progress
=
False
for
handle
,
xfer_start_time
in
handles
:
for
handle
in
handles
:
try
:
try
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
if
xfer_state
==
"DONE"
:
if
xfer_state
==
"DONE"
:
...
@@ -1946,6 +1973,7 @@ class NixlConnectorWorker:
...
@@ -1946,6 +1973,7 @@ class NixlConnectorWorker:
self
.
_read_blocks
(
self
.
_read_blocks
(
request_id
=
req_id
,
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote_engine_id
,
dst_engine_id
=
meta
.
remote_engine_id
,
remote_request_id
=
meta
.
remote_request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote_block_ids
,
remote_block_ids
=
meta
.
remote_block_ids
,
)
)
...
@@ -1956,6 +1984,7 @@ class NixlConnectorWorker:
...
@@ -1956,6 +1984,7 @@ class NixlConnectorWorker:
remote_block_ids
:
list
[
int
],
remote_block_ids
:
list
[
int
],
dst_engine_id
:
str
,
dst_engine_id
:
str
,
request_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
):
):
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
...
@@ -1988,7 +2017,7 @@ class NixlConnectorWorker:
...
@@ -1988,7 +2017,7 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
notif_id
=
f
"
{
request_id
}
:
{
tp_ratio
}
"
.
encode
()
notif_id
=
f
"
{
remote_
request_id
}
:
{
tp_ratio
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
# just notify P worker that we have the blocks we need.
...
@@ -2096,7 +2125,7 @@ class NixlConnectorWorker:
...
@@ -2096,7 +2125,7 @@ class NixlConnectorWorker:
self
.
nixl_wrapper
.
transfer
(
handle
)
self
.
nixl_wrapper
.
transfer
(
handle
)
# Use handle to check completion in future step().
# Use handle to check completion in future step().
self
.
_recving_transfers
[
request_id
].
append
(
(
handle
,
time
.
perf_counter
())
)
self
.
_recving_transfers
[
request_id
].
append
(
handle
)
except
Exception
:
except
Exception
:
logger
.
exception
(
logger
.
exception
(
"NIXL transfer setup/initiation failed for request %s. "
"NIXL transfer setup/initiation failed for request %s. "
...
@@ -2227,7 +2256,7 @@ class NixlConnectorWorker:
...
@@ -2227,7 +2256,7 @@ class NixlConnectorWorker:
"""Shutdown the connector worker."""
"""Shutdown the connector worker."""
self
.
_handshake_initiation_executor
.
shutdown
(
wait
=
False
)
self
.
_handshake_initiation_executor
.
shutdown
(
wait
=
False
)
for
handles
in
self
.
_recving_transfers
.
values
():
for
handles
in
self
.
_recving_transfers
.
values
():
for
handle
,
_
in
handles
:
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_recving_transfers
.
clear
()
self
.
_recving_transfers
.
clear
()
if
self
.
src_xfer_side_handle
:
if
self
.
src_xfer_side_handle
:
...
...
vllm/distributed/kv_transfer/kv_lookup_buffer/base.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from
abc
import
ABC
,
abstractmethod
import
torch
class KVCacheBufferBase(ABC):
    """Common abstract root shared by every KVCache buffer."""

    @abstractmethod
    def close(self) -> None:
        """Shut the buffer down and free what it holds.

        Implementations clean up the resources tied to the KVCache buffer
        once the buffer is no longer needed.

        Raises:
            NotImplementedError: Raised by this base implementation;
                subclasses must provide their own.
        """
        raise NotImplementedError()
class KVLookupBufferBase(KVCacheBufferBase):
    """Abstract interface of a KVCache *lookup* buffer.

    Entries are addressed by a two-part key:

    - ``input_tokens``: the token IDs of the request.
    - ``roi``: a binary mask laid over ``input_tokens``. The KV cache may
      cover only part of the input tokens (for example when vLLM is
      connected to an external KV cache service); ``roi`` marks the subset
      of tokens the KV cache is associated with.
      NOTE: ``roi`` could be extended to describe which part of the KV the
      current process holds (TP and PP may leave each process with only a
      part of it). This is not implemented for now.

    Each entry stores:

    - ``key``: the key tensor in the KV cache.
    - ``value``: the value tensor in the KV cache.
    - ``hidden``: the final hidden state generated by model forwarding.
      Transmitting it lets vLLM bypass further model forwarding.
    """

    @abstractmethod
    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        """Add one entry to the lookup buffer.

        Conceptually equivalent to::

            buffer[input_tokens, roi] = [key, value, hidden]

        FIXME: in the future, we should only have two arguments, key and
        value, where key is a tensor dict and value is a tensor dict.

        FIXME: we should transmit both sampler outputs and the hidden
        states.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): binary mask on top of the input tokens.
            key (torch.Tensor): key tensor in the KV cache.
            value (torch.Tensor): value tensor in the KV cache.
            hidden (torch.Tensor): final hidden state tensor generated
                during model forwarding, used to bypass model forwarding.

        Raises:
            NotImplementedError: Raised by this base implementation;
                subclasses must provide their own.
        """
        raise NotImplementedError()

    @abstractmethod
    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        """Fetch an entry from the lookup buffer and *remove* it.

        Conceptually equivalent to::

            ret = buffer.pop(input_tokens, roi)
            return ret

        When both ``input_tokens`` and ``roi`` are ``None``, any one of the
        buffered KV caches is selected, returned and removed from the
        buffer. This is useful when offloading KV cache to a KV cache
        storage service.

        Args:
            input_tokens (torch.Tensor): token IDs.
            roi (torch.Tensor): binary mask on top of the input tokens.

        Returns:
            list[Optional[torch.Tensor]]: A list of tensors. Can be None.

        Raises:
            NotImplementedError: Raised by this base implementation;
                subclasses must provide their own.
        """
        raise NotImplementedError()
class KVStoreBufferBase(KVCacheBufferBase):
    """Abstract interface of a KVCache storage buffer with key-value semantics.

    The buffer is a plain key-value store offering basic put/get
    operations, which enables flexible, fine-grained control over KVCache
    transfer.

    It behaves like a distributed key-value store where:

    - Key: a unique string identifier for the cached entry.
    - Value:
        - the tensor to be stored and retrieved, or
        - ``None`` (indicating deletion or an empty value).
    """

    @abstractmethod
    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        """Store ``value`` under ``key``.

        Args:
            key (str): Unique identifier for a tensor. The tensor could be
                the key cache tensor, the value cache tensor, or the hidden
                state tensor generated during model forwarding.
            value (Optional[torch.Tensor]): Tensor to be stored.

        Raises:
            NotImplementedError: Raised by this base implementation;
                subclasses must provide their own.
        """
        raise NotImplementedError()

    @abstractmethod
    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Look up the value stored under ``key``.

        Args:
            key (str): Unique identifier for a tensor. The tensor could be
                the key cache tensor, the value cache tensor, or the hidden
                state tensor generated during model forwarding.

        Returns:
            Optional[torch.Tensor]: The stored tensor if it exists, None
            otherwise.

        Raises:
            NotImplementedError: Raised by this base implementation;
                subclasses must provide their own.
        """
        raise NotImplementedError()
vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import
json
import
os
from
dataclasses
import
dataclass
import
torch
from
safetensors.torch
import
load
as
safetensors_load
from
safetensors.torch
import
save
as
safetensors_save
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_lookup_buffer.base
import
KVStoreBufferBase
from
vllm.logger
import
init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE
=
3355443200
# 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE
=
1073741824
# 1.0 GiB
logger
=
init_logger
(
__name__
)
@dataclass
class MooncakeStoreConfig:
    """Settings handed to ``MooncakeDistributedStore.setup``.

    Instances are normally built from a JSON file, either directly with
    `from_file` or through the ``MOONCAKE_CONFIG_PATH`` environment
    variable with `load_from_env`.
    """

    # The fields are passed positionally, in this order, to the store's
    # ``setup`` call by ``MooncakeStore.__init__``.
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Load the config from a JSON file.

        Missing optional keys fall back to defaults:
        ``global_segment_size`` -> ``DEFAULT_GLOBAL_SEGMENT_SIZE``,
        ``local_buffer_size`` -> ``DEFAULT_LOCAL_BUFFER_SIZE``,
        ``protocol`` -> ``"tcp"`` and ``device_name`` -> ``""``.
        ``local_hostname``, ``metadata_server`` and
        ``master_server_address`` have no default and become ``None`` when
        the key is absent.

        Args:
            file_path: Path of the JSON configuration file.

        Returns:
            The populated configuration object.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON documents are UTF-8; do not depend on the locale's default
        # encoding when reading the file.
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load config from a file specified in the environment variable.

        Returns:
            The configuration read from the file that
            ``MOONCAKE_CONFIG_PATH`` points to.

        Raises:
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is not set.
        """
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
    """Key-value KVCache buffer backed by a ``MooncakeDistributedStore``."""

    def __init__(
        self,
        config: VllmConfig,
    ):
        """Create the store and set it up from ``MOONCAKE_CONFIG_PATH``.

        Args:
            config: The vLLM configuration. It is not read by this
                constructor; the Mooncake settings come from the
                environment-specified config file instead.

        Raises:
            ImportError: If the ``mooncake`` package is not installed.
            ValueError: If loading the configuration fails with a
                ``ValueError`` (e.g. the environment variable is unset).
            Exception: Any other error raised while creating or setting up
                the store is logged and re-raised unchanged.
        """
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")

            settings = self.config
            self.store.setup(
                settings.local_hostname,
                settings.metadata_server,
                settings.global_segment_size,
                settings.local_buffer_size,
                settings.protocol,
                settings.device_name,
                settings.master_server_address,
            )
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def close(self):
        """Do nothing.

        MooncakeDistributedStore will automatically call the destructor,
        so it is unnecessary to close it manually.
        """
        pass

    def put(
        self,
        key: str,
        value: torch.Tensor | None,
    ) -> None:
        """Store ``value`` under ``key``; a ``None`` value is ignored."""
        # A message queue needs to be introduced before making it asynchronous.
        if value is None:
            return
        self._put_impl(key, value)

    def get(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Return the tensor stored under ``key``, or ``None`` if absent."""
        # A message queue needs to be introduced before making it asynchronous.
        return self._get_impl(key)

    def _put_impl(
        self,
        key: str,
        value: torch.Tensor,
    ) -> None:
        """Put KVCache to Mooncake Store"""
        # Record where the tensor lived: the CUDA device index, or -1 for
        # any non-CUDA device, so that `_get_impl` can restore it.
        source_device = value.device
        device_id = -1
        if source_device.type == "cuda":
            device_id = source_device.index
        device_tensor = torch.tensor(device_id, dtype=torch.int32)
        payload = {"tensor": value, "device_id": device_tensor}
        value_bytes = safetensors_save(payload)
        try:
            self.store.put(key, value_bytes)
        except TypeError as err:
            logger.error("Failed to put value into Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_impl(
        self,
        key: str,
    ) -> torch.Tensor | None:
        """Get KVCache from Mooncake Store"""
        try:
            data = self.store.get(key)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err

        # An empty/falsy payload means there is nothing to return.
        if not data:
            return None

        loaded_tensors = safetensors_load(data)
        tensor = loaded_tensors["tensor"]
        device_id = int(loaded_tensors["device_id"].item())
        # A non-negative id is the CUDA index recorded by `_put_impl`;
        # a negative id means the tensor goes to the CPU.
        if device_id >= 0:
            device = torch.device("cuda", device_id)
        else:
            device = torch.device("cpu")
        return tensor.to(device)
vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid race conditions
- Handles buffer size constraints and provides a backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import
threading
from
collections
import
deque
import
torch
from
vllm.distributed.kv_transfer.kv_lookup_buffer.base
import
KVLookupBufferBase
from
vllm.distributed.kv_transfer.kv_pipe.base
import
KVPipeBase
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class SimpleBuffer(KVLookupBufferBase):
    """Lookup buffer that hands buffered KV entries over a pair of pipes.

    The producer queues entries with `insert` and answers incoming requests
    on a background thread running `drop_select_handler`. The consumer
    fetches entries with `drop_select`.
    """

    def __init__(
        self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
    ):
        """
        signal_pipe: on CPU

        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen to new request while transmitting
        KV cache. Luckily CPU recv only blocks the current thread so we use
        CPU recv to listen to new request.

        data_pipe: on device (e.g. GPU)

        buffer_size_thresh: upper bound, in bytes, on the total size of the
        buffered tensors; `insert` blocks while adding an entry would
        exceed it.
        """
        self.buffer: deque[list[torch.Tensor]] = deque()

        self.buffer_size = 0
        self.buffer_size_threshold = buffer_size_thresh
        # Guards `buffer`/`buffer_size` and wakes up waiting threads.
        self.buffer_cv = threading.Condition()
        self.signal_pipe = signal_pipe
        self.data_pipe = data_pipe
        self.request_handling_thread: threading.Thread | None = None

        self.normal_signal = torch.tensor([0], device="cpu")
        self.end_signal = None

    def _matches(
        self,
        tokens_roi_sender: list[torch.Tensor],
        tokens_roi_recver: list[torch.Tensor],
    ):
        """Score how well a buffered entry matches a consumer query.

        Args:
            tokens_roi_sender: tokens and roi of the producer (in the buffer).
            tokens_roi_recver: tokens and roi of the consumer (query).

        Returns:
            ``True`` when the query carries no tokens (anything matches);
            otherwise the compared prefix length if the masked tokens agree
            on that prefix, or ``0`` if they do not. Callers test the
            result with ``> 0``.
        """
        tokens_sender = tokens_roi_sender[0]
        tokens_recver = tokens_roi_recver[0]
        roi_sender = tokens_roi_sender[1]
        roi_recver = tokens_roi_recver[1]

        if tokens_recver is None:
            # consumer sends an empty request
            # semantics: DROP SELECT * LIMIT 1
            # so any of the data in the buffer can be drop-selected
            return True

        # Assuming that roi is a binary mask on tokens
        tokens_sender = tokens_sender[roi_sender]
        tokens_recver = tokens_recver[roi_recver]

        # simple common prefix matching
        min_length = min(len(tokens_sender), len(tokens_recver))
        if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
            return min_length

        return 0

    def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
        """Send ``tensor`` on the data pipe and release its size from the
        buffer accounting.

        ``None`` is rejected here; callers must send it through the data
        pipe directly.
        """
        assert tensor is not None, "Use self.data_pipe.send(None) instead"
        self.buffer_size -= tensor.element_size() * tensor.numel()
        if tensor.dtype == torch.bool:
            # Bool tensors travel as float; the receiver converts them back
            # with a `> 0.5` comparison.
            tensor = tensor.float()
        self.data_pipe.send_tensor(tensor)

    def _get_element_size(self, data: list | torch.Tensor | None):
        """Return the byte size of ``data`` (0 for ``None``/empty values).

        Raises:
            AssertionError: If ``data`` is a truthy non-tensor value.
        """
        if isinstance(data, torch.Tensor):
            return data.element_size() * data.numel()
        if not data:
            # cannot perform `not data` on a tensor
            # so this check needs to go after the check above
            return 0

        raise AssertionError(f"Unknown data type {type(data)}")

    def _add_to_buffer(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ):
        """Clone the given tensors and append them as one buffer entry.

        Blocks on the condition variable while the entry would push the
        buffer past ``buffer_size_threshold``.
        """
        # Clone so the buffered entry does not alias the caller's tensors.
        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone()
        if isinstance(key, torch.Tensor):
            key = key.clone()
        if isinstance(value, torch.Tensor):
            value = value.clone()
        if isinstance(hidden, torch.Tensor):
            hidden = hidden.clone()

        buffer_item = [input_tokens, roi, key, value, hidden]
        data_size = sum(self._get_element_size(data) for data in buffer_item)

        with self.buffer_cv:
            if self.buffer_size + data_size > self.buffer_size_threshold:
                # log outside the while loop to avoid this message being logged
                # repeatedly.
                logger.debug("KV transfer buffer is full. Handling...")
                while self.buffer_size + data_size > self.buffer_size_threshold:
                    self.buffer_cv.wait()

            self.buffer_size += data_size
            self.buffer.append(buffer_item)
            self.buffer_cv.notify()

    def _is_end_signal(self, signal):
        """Return whether ``signal`` is the end-of-stream marker (``None``)."""
        return signal is None

    def _is_buffer_available(self, tokens_roi_recver: list[torch.Tensor]) -> bool:
        """Rotate the buffer until a matching entry sits at its left end.

        Returns ``True`` with the match at ``self.buffer[0]``, or ``False``
        after one full rotation without a match. Must be called with
        ``self.buffer_cv`` held.
        """
        # perform input tokens and roi matching
        # FIXME: this matching is O(n), ideally it should be O(1)
        # but this buffer size won't (and shouldn't) be too large so
        # the fix is not urgent.
        for _ in range(len(self.buffer)):
            if self._matches(self.buffer[0], tokens_roi_recver) > 0:
                return True
            # rotate the element we just accessed to the end
            self.buffer.rotate(-1)
        return False

    def drop_select_handler(self):
        """Serve drop-select requests until the end signal arrives.

        For each request: read the query (tokens and roi) from the data
        pipe, wait until a matching entry is buffered, pop it and send its
        members back over the data pipe. A ``RuntimeError`` mentioning
        "Connection closed by peer" ends the loop quietly; any other
        ``RuntimeError`` propagates.
        """
        try:
            while True:
                signal = self.signal_pipe.recv_tensor()
                if self._is_end_signal(signal):
                    logger.info("Received end signal!")
                    break

                input_tokens = self.data_pipe.recv_tensor()

                roi = self.data_pipe.recv_tensor()
                assert roi is not None, (
                    "Please provide the roi when sending drop-select request"
                )
                # The consumer sends the mask as float; turn it back to bool.
                roi = roi > 0.5
                tokens_roi_recver = [input_tokens, roi]

                with self.buffer_cv:
                    while not self._is_buffer_available(tokens_roi_recver):
                        logger.debug("KV transfer buffer is not available. Waiting...")
                        self.buffer_cv.wait()
                    # The buffered tensors were cloned on insertion, so they
                    # stay valid until sending finishes.
                    matched_item = self.buffer.popleft()
                    for tensor in matched_item:
                        if tensor is None:
                            # None members were counted as 0 bytes on
                            # insertion: forward them as-is (the pipe
                            # supports None) without touching buffer_size,
                            # instead of aborting the reply half-way.
                            self.data_pipe.send_tensor(None)
                        else:
                            self._send_tensor_and_dec_size(tensor)
                    self.buffer_cv.notify()

        except RuntimeError as e:
            if "Connection closed by peer" not in str(e):
                # Bare raise keeps the original traceback intact.
                raise

        logger.debug("Closing drop_select_handler")

    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        """Request a matching entry from the producer and return it.

        Args:
            input_tokens: token IDs of the query, or ``None``.
            roi: binary mask on top of the input tokens, or ``None``.

        Returns:
            ``[input_tokens, roi, key, value, hidden]`` as received from
            the data pipe, with ``roi`` converted back to a bool mask when
            it is not ``None``.
        """
        assert self.request_handling_thread is None, (
            "drop_select should be called by the KV cache consumer "
            "(e.g. the decode vLLM instance)"
        )

        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone().float()

        self.signal_pipe.send_tensor(self.normal_signal)
        self.data_pipe.send_tensor(input_tokens)
        self.data_pipe.send_tensor(roi)

        input_tokens = self.data_pipe.recv_tensor()
        roi = self.data_pipe.recv_tensor()
        if roi is not None:
            # convert from float tensor to bool tensor
            # as PyNccl does not support sending bool tensor
            roi = roi > 0.5
        key = self.data_pipe.recv_tensor()
        value = self.data_pipe.recv_tensor()
        hidden = self.data_pipe.recv_tensor()

        return [input_tokens, roi, key, value, hidden]

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        """Buffer one entry and make sure the request handler is running."""
        self._add_to_buffer(input_tokens, roi, key, value, hidden)

        # when calling the insert, the current process is a sender
        # need to launch the request handler and start listening to request.
        if self.request_handling_thread is None:
            self.request_handling_thread = threading.Thread(
                target=self.drop_select_handler
            )
            self.request_handling_thread.start()

    def close(self):
        """Join the handler thread if one was started, else send the end
        signal on the signal pipe."""
        if (
            hasattr(self, "request_handling_thread")
            and self.request_handling_thread is not None
        ):
            self.request_handling_thread.join()

        else:
            # TODO: have a explicit close signal and have a explicit way to
            # check if it's requester
            self.signal_pipe.send_tensor(self.end_signal)
vllm/distributed/kv_transfer/kv_pipe/base.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from
abc
import
ABC
,
abstractmethod
import
torch
class
KVPipeBase
(
ABC
):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@
abstractmethod
def
send_tensor
(
self
,
tensor
:
torch
.
Tensor
|
None
)
->
None
:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise
NotImplementedError
@
abstractmethod
def
recv_tensor
(
self
)
->
torch
.
Tensor
|
None
:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise
NotImplementedError
@
abstractmethod
def
close
(
self
)
->
None
:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise
NotImplementedError
Prev
1
…
16
17
18
19
20
21
22
23
24
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment