Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4f4e7ef
Unverified
Commit
f4f4e7ef
authored
Aug 04, 2025
by
lkchen
Committed by
GitHub
Aug 04, 2025
Browse files
[V0 deprecation][P/D] Deprecate v0 `KVConnectorBase` code (1/2) (#21785)
Signed-off-by:
Linkun Chen
<
github@lkchen.net
>
parent
5ea71ff4
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
31 additions
and
1040 deletions
+31
-1040
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+0
-1
tests/kv_transfer/test_disagg.py
tests/kv_transfer/test_disagg.py
+0
-120
vllm/distributed/kv_transfer/kv_connector/base.py
vllm/distributed/kv_transfer/kv_connector/base.py
+4
-136
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+13
-55
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
...distributed/kv_transfer/kv_connector/lmcache_connector.py
+0
-99
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
...uted/kv_transfer/kv_connector/mooncake_store_connector.py
+0
-203
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
.../distributed/kv_transfer/kv_connector/simple_connector.py
+0
-329
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+4
-5
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+4
-4
vllm/distributed/kv_transfer/kv_connector_agent.py
vllm/distributed/kv_transfer/kv_connector_agent.py
+0
-77
vllm/distributed/kv_transfer/kv_transfer_state.py
vllm/distributed/kv_transfer/kv_transfer_state.py
+2
-7
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+1
-1
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+3
-3
No files found.
.buildkite/test-pipeline.yaml
View file @
f4f4e7ef
...
...
@@ -749,7 +749,6 @@ steps:
# this test fails consistently.
# TODO: investigate and fix
-
VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-
VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
-
pytest -v -s models/multimodal/generation/test_maverick.py
...
...
tests/kv_transfer/test_disagg.py
deleted
100644 → 0
View file @
5ea71ff4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
subprocess
import
sys
import
time
from
subprocess
import
Popen
import
pytest
import
requests
import
torch
# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
    """Start one prefill and one decode vLLM API server for this module.

    The prefill server listens on port 8100 with ``CUDA_VISIBLE_DEVICES=0``
    and the decode server on port 8200 with ``CUDA_VISIBLE_DEVICES=1``. Both
    are launched with a ``PyNcclConnector`` KV-transfer config. The module is
    skipped when fewer than two CUDA devices are visible.

    Every server process that was started is terminated and reaped on exit,
    including when startup fails before the tests run.
    """
    if torch.cuda.device_count() < 2:
        pytest.skip("Skipping test: fewer than 2 GPUs available")

    # Set up environment variables
    # The shell pipeline is a fixed string (no external input), so
    # shell=True is used only for the pipe to awk.
    VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
                                           shell=True).decode().strip()
    os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP

    def launch(port, kv_role, kv_rank, gpu):
        """Spawn one API server pinned to a single GPU and return its Popen."""
        kv_transfer_config = (
            '{"kv_connector":"PyNcclConnector","kv_role":"%s",'
            '"kv_rank":%d,"kv_parallel_size":2}' % (kv_role, kv_rank))
        cmd = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            "meta-llama/Llama-3.2-1B-Instruct",
            "--port",
            port,
            "--gpu-memory-utilization",
            "0.5",
            "--max-model-len",
            "1000",
            "--kv-transfer-config",
            kv_transfer_config,
        ]
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = gpu
        return Popen(cmd, env=env)

    procs = []
    try:
        # Start prefill instance
        procs.append(launch("8100", "kv_producer", 0, "0"))
        # Start decode instance
        procs.append(launch("8200", "kv_consumer", 1, "1"))

        # Wait for servers to be ready
        assert wait_for_server(8100), "Prefill server did not start in time"
        assert wait_for_server(8200), "Decode server did not start in time"

        # Yield to the test function and handle teardown after tests
        yield
    finally:
        # Cleanup runs even if startup failed, so no server process is left
        # behind holding a GPU or a port.
        for proc in procs:
            proc.terminate()
        for proc in procs:
            proc.wait()
# Helper function to wait for server
def wait_for_server(port, timeout=240):
    """Poll a local server until its completions endpoint responds.

    Args:
        port: TCP port of the server on localhost.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as a GET on ``/v1/completions`` returns status 200 or
        405, False if that does not happen within ``timeout`` seconds.
    """
    url = f"http://localhost:{port}/v1/completions"
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each probe so one hung connection cannot outlast the
            # overall deadline.
            response = requests.get(url, timeout=5)
        except (requests.ConnectionError, requests.Timeout):
            # Server is not accepting connections (or not answering) yet.
            pass
        else:
            # 405 also counts as "up": the tests in this file POST to this
            # endpoint, so a GET may be rejected by a running server.
            if response.status_code in (200, 405):
                return True
        # Back off after every unsuccessful probe, not only after connection
        # errors, so an unexpected status code does not cause a busy loop.
        time.sleep(1)
    return False
# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
    """Send the prompt to the prefill server, then to the decode server.

    Each request must come back with HTTP 200. The prefill request asks for
    a single token, the decode request for ten.
    """
    # (endpoint, max_tokens) in the order the requests must be issued:
    # prefill first, decode second.
    stages = (
        ("http://localhost:8100/v1/completions", 1),
        ("http://localhost:8200/v1/completions", 10),
    )
    for endpoint, token_budget in stages:
        payload = {
            "model": "meta-llama/Llama-3.2-1B-Instruct",
            "prompt": prompt,
            "max_tokens": token_budget,
            "temperature": 0
        }
        response = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            json=payload)
        assert response.status_code == 200
vllm/distributed/kv_transfer/kv_connector/base.py
View file @
f4f4e7ef
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
"""Defines the base type for KV cache connectors."""
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
class KVConnectorBase(ABC):
    """Interface implemented by every KV connector.

    Subclasses provide a constructor, ``close`` and the two transfer hooks:

    1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
    2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
    """

    @abstractmethod
    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Release the buffer and any other resources held by the connector.

        Called when the connector is no longer needed.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Send KV caches and hidden states through the connector.

        Implementations take the input tokens, the KV caches and the
        hidden/intermediate states of a model run and ship them to the
        decode instance.

        Args:
            model_executable: The model executable; carries the start and
                end layer information.
            model_input: The input metadata from vLLM.
            kv_caches: KV caches (keys and values), one entry per layer.
            hidden_or_intermediate_states: The hidden or intermediate
                states associated with the tokens.

        Returns:
            None
        """
        raise NotImplementedError

    @abstractmethod
    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Receive KV caches and hidden states through the connector.

        Implementations try to fetch the KV caches and hidden states for
        the input tokens. When everything required is received the model
        forward pass can be bypassed; otherwise execution falls back to the
        normal vLLM model forwarding.

        Args:
            model_executable: The model executable from the vLLM model
                runner.
            model_input: The model input from the vLLM model runner.
            kv_caches: KV caches, one entry per layer.

        Returns:
            A 3-tuple of:

            - hidden_or_intermediate_states: Concatenated hidden states if
              all required data was retrieved, otherwise `None`.
            - bypass_model_exec: True when model execution can be skipped,
              False when it has to be redone.
            - model_input: Optionally adjusted input metadata for
              re-execution when `bypass_model_exec=False`.
        """
        raise NotImplementedError

    @classmethod
    def get_required_kvcache_layout(
            cls, vllm_config: "VllmConfig") -> Optional[str]:
        """Return the KV cache layout this connector requires.

        Args:
            vllm_config: The vllm config.

        Returns:
            The required KV cache layout, e.g. HND or NHD, or None when the
            connector does not require a specific layout. This default
            implementation always returns None.
        """
        return None
KVConnectorBase
=
KVConnectorBase_V1
KVConnectorBaseType
=
KVConnectorBase_V1
KVConnectorBaseType
=
Union
[
KVConnectorBase
,
KVConnectorBase
_V1
]
__all__
=
[
"
KVConnectorBase
"
,
"
KVConnectorBase
Type"
]
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
f4f4e7ef
...
...
@@ -5,14 +5,10 @@ import importlib
from
typing
import
TYPE_CHECKING
,
Callable
import
vllm.envs
as
envs
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorRole
from
vllm.logger
import
init_logger
from
.base
import
KVConnectorBase
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
...
...
@@ -20,7 +16,7 @@ logger = init_logger(__name__)
class
KVConnectorFactory
:
_registry
:
dict
[
str
,
Callable
[[],
type
[
KVConnectorBase
Type
]]]
=
{}
_registry
:
dict
[
str
,
Callable
[[],
type
[
KVConnectorBase
]]]
=
{}
@
classmethod
def
register_connector
(
cls
,
name
:
str
,
module_path
:
str
,
...
...
@@ -29,28 +25,23 @@ class KVConnectorFactory:
if
name
in
cls
.
_registry
:
raise
ValueError
(
f
"Connector '
{
name
}
' is already registered."
)
def
loader
()
->
type
[
KVConnectorBase
Type
]:
def
loader
()
->
type
[
KVConnectorBase
]:
module
=
importlib
.
import_module
(
module_path
)
return
getattr
(
module
,
class_name
)
cls
.
_registry
[
name
]
=
loader
@
classmethod
def
create_connector_v0
(
cls
,
rank
:
int
,
local_rank
:
int
,
config
:
"VllmConfig"
)
->
KVConnectorBase
:
if
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V0 Connector, "
def
create_connector
(
cls
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
)
->
KVConnectorBase
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
connector_cls
=
cls
.
get_connector_class
(
config
.
kv_transfer_config
)
assert
issubclass
(
connector_cls
,
KVConnectorBase
)
return
connector_cls
(
rank
,
local_rank
,
config
)
@
classmethod
def
get_connector_class
(
cls
,
kv_transfer_config
:
"KVTransferConfig"
)
->
type
[
KVConnectorBaseType
]:
"""Get the connector class by name."""
kv_transfer_config
=
config
.
kv_transfer_config
connector_name
=
kv_transfer_config
.
kv_connector
if
connector_name
in
cls
.
_registry
:
connector_cls
=
cls
.
_registry
[
connector_name
]()
...
...
@@ -61,21 +52,7 @@ class KVConnectorFactory:
f
"Unsupported connector type:
{
connector_name
}
"
)
connector_module
=
importlib
.
import_module
(
connector_module_path
)
connector_cls
=
getattr
(
connector_module
,
connector_name
)
return
connector_cls
@
classmethod
def
create_connector_v1
(
cls
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
kv_transfer_config
=
config
.
kv_transfer_config
connector_cls
=
cls
.
get_connector_class
(
kv_transfer_config
)
assert
issubclass
(
connector_cls
,
KVConnectorBase_V1
)
assert
issubclass
(
connector_cls
,
KVConnectorBase
)
logger
.
info
(
"Creating v1 connector with name: %s and engine_id: %s"
,
connector_cls
.
__name__
,
kv_transfer_config
.
engine_id
)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
...
...
@@ -92,25 +69,6 @@ class KVConnectorFactory:
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory
.
register_connector
(
"PyNcclConnector"
,
"vllm.distributed.kv_transfer.kv_connector.simple_connector"
,
"SimpleConnector"
)
KVConnectorFactory
.
register_connector
(
"MooncakeConnector"
,
"vllm.distributed.kv_transfer.kv_connector.simple_connector"
,
"SimpleConnector"
)
KVConnectorFactory
.
register_connector
(
"LMCacheConnector"
,
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector"
,
"LMCacheConnector"
)
KVConnectorFactory
.
register_connector
(
"MooncakeStoreConnector"
,
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector"
,
"MooncakeStoreConnector"
)
KVConnectorFactory
.
register_connector
(
"SharedStorageConnector"
,
...
...
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
deleted
100644 → 0
View file @
5ea71ff4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches.
"""
from
typing
import
TYPE_CHECKING
,
Union
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
class LMCacheConnector(KVConnectorBase):
    """KV connector that hands KV-cache store/retrieve off to LMCache."""

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        """Create the LMCache engine and keep handles to its adapter hooks.

        Args:
            rank: Not used by this connector.
            local_rank: Not used by this connector.
            config: The vLLM config; its model, parallel and cache configs
                are passed on to LMCache.
        """
        self.transfer_config = config.kv_transfer_config
        self.vllm_config = config

        # lmcache is imported here, inside __init__, rather than at module
        # import time.
        from lmcache.experimental.cache_engine import LMCacheEngineBuilder
        from lmcache.integration.vllm.utils import ENGINE_NAME
        from lmcache.integration.vllm.vllm_adapter import (
            RetrieveStatus, StoreStatus, init_lmcache_engine,
            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
            lmcache_store_kv)

        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        model_cfg = config.model_config
        parallel_cfg = config.parallel_config
        cache_cfg = config.cache_config

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(model_cfg, parallel_cfg, cache_cfg)

        # The builder class and engine name are what close() needs later.
        self.lmcache_engine_name = ENGINE_NAME
        self.lmcache_engine_builder = LMCacheEngineBuilder

        self.model_config = model_cfg
        self.parallel_config = parallel_cfg
        self.cache_config = cache_cfg

        # Adapter entry points used by the send/recv methods.
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.lmcache_should_retrieve = lmcache_should_retrieve
        self.lmcache_should_store = lmcache_should_store
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Retrieve KV caches through LMCache.

        Returns:
            ``(hidden_or_intermediate_states, bypass_model_exec,
            model_input)`` - the values produced by ``lmcache_retrieve_kv``,
            reordered.
        """
        should_retrieve = self.lmcache_should_retrieve(model_input)
        retrieved = self.lmcache_retrieve_kv(model_executable, model_input,
                                             self.cache_config, kv_caches,
                                             should_retrieve)
        # The adapter returns (model_input, bypass, hidden); callers of this
        # method expect (hidden, bypass, model_input).
        new_model_input, bypass_model_exec, hidden_states = retrieved
        return hidden_states, bypass_model_exec, new_model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Store KV caches through LMCache.

        ``hidden_or_intermediate_states`` is accepted but not forwarded to
        LMCache.
        """
        should_store = self.lmcache_should_store(model_input)
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            self.cache_config,
            model_executable,
            model_input,
            kv_caches,
            should_store,
        )

    def close(self):
        """Destroy the LMCache engine registered under the stored name."""
        builder = self.lmcache_engine_builder
        builder.destroy(self.lmcache_engine_name)
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
deleted
100644 → 0
View file @
5ea71ff4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
"""
import
hashlib
from
typing
import
TYPE_CHECKING
,
Union
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
class MooncakeStoreConnector(KVConnectorBase):
    """KV connector that exchanges KV caches through a MooncakeStore.

    Per request, the producer stores the stacked key/value tensors and the
    hidden states under keys derived from a hash of the request's tokens and
    the local rank; the consumer looks the same keys up.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        """Create the connector and its backing MooncakeStore.

        Args:
            rank: Not used by this connector.
            local_rank: Stored as ``local_tp_rank`` and used as a suffix of
                every store key.
            config: The vLLM config; ``kv_transfer_config.kv_connector``
                must be ``"MooncakeStoreConnector"``.

        Raises:
            ValueError: If the configured connector is not
                ``MooncakeStoreConnector`` or the ``MOONCAKE_CONFIG_PATH``
                environment variable is not set.
        """
        self.kv_transfer_config = config.kv_transfer_config
        self.kv_helper = kv_helper(config)
        self.local_tp_rank = local_rank

        # Init kv_store
        connector_name = self.kv_transfer_config.kv_connector
        if connector_name != "MooncakeStoreConnector":
            logger.error("Can not find %s", connector_name)
            # Fail with an explicit error instead of falling through to a
            # check on a kv_store attribute that was never assigned.
            raise ValueError(f"Unsupported connector type: {connector_name}")

        # Check if MOONCAKE_CONFIG_PATH is set
        import os
        if os.getenv('MOONCAKE_CONFIG_PATH') is None:
            raise ValueError(
                "To use MooncakeStoreConnector, you need to pass the ENV: "
                "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")

        from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import (  # noqa: E501
            MooncakeStore)
        logger.info("Initializing KVStoreConnector under kv_transfer_config %s",
                    self.kv_transfer_config)
        self.kv_store = MooncakeStore(config)

    def close(self) -> None:
        """Close the underlying KV store and release its resources."""
        self.kv_store.close()

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Store the KV cache and hidden states of every request in the batch.

        For each request, the key and value slices of all local layers are
        stacked into one tensor (index 0 = keys, index 1 = values; the next
        dimension is the layer, counted from ``start_layer``) and stored
        under ``"<token-hash>_<local_tp_rank>"``. The matching slice of
        ``hidden_or_intermediate_states`` is stored under
        ``"<token-hash>_hidden_<local_tp_rank>"``.
        """
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

        # Running offset into the flattened batch; avoids re-summing the
        # sequence-length prefix for every request.
        next_start = 0
        for slen in seq_lens:
            start_pos = next_start
            end_pos = start_pos + slen
            next_start = end_pos

            current_tokens = input_tokens_tensor[start_pos:end_pos]
            store_key_prefix = self.tensor_hash(current_tokens)
            # Same slots for every layer of this request.
            current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

            keys, values = [], []
            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
                    kv_cache, num_heads, head_size)
                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            stacked_keys = torch.cat(keys, dim=0)
            stacked_values = torch.cat(values, dim=0)
            kvcache_to_sent = torch.stack((stacked_keys, stacked_values),
                                          dim=0)
            store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
            self.kv_store.put(store_kvcache_key, kvcache_to_sent)

            hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
            self.kv_store.put(hidden_key,
                              hidden_or_intermediate_states[start_pos:end_pos])

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Load KV caches and hidden states for the prefill requests.

        Returns:
            ``(hidden_or_intermediate_states, bypass_model_exec,
            model_input)``. ``bypass_model_exec`` is True only when every
            request in the batch was found in the store; in that case the
            first element is the concatenation of the stored hidden states,
            otherwise it is None. ``model_input`` is returned unchanged.
        """
        bypass_model_exec = True

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        hidden_or_intermediate_states_for_one_req = []

        # Running offset into the flattened batch (see the send path).
        next_start = 0
        for slen in seq_lens:
            start_pos = next_start
            end_pos = start_pos + slen
            next_start = end_pos

            if start_pos >= num_prefill_tokens:
                # This can happen during inflight batching. See:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning("You should set --enable_chunked_prefill=False "
                               "and --max_num_batched_tokens "
                               "should be equal to max_seq_len_to_capture")
                bypass_model_exec = False
                assert start_pos == num_prefill_tokens
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]

            # get roi for current seq
            load_key_prefix = self.tensor_hash(current_tokens)
            load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
            remote_kv = self.kv_store.get(load_kvcache_key)
            hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
            hidden = self.kv_store.get(hidden_key)

            if remote_kv is None or hidden is None:
                # didn't find any match.
                bypass_model_exec = False
                continue

            num_computed_tokens = current_tokens.shape[0]

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # call self.kv_store to get kv layer by layer
            for layer_id in range(start_layer, end_layer):
                layer = model_executable.model.layers[layer_id]
                # The send path stacks layers starting at index 0, so both
                # the local KV cache list and the fetched tensor are indexed
                # relative to start_layer.
                local_layer_idx = layer_id - start_layer

                # get kvcache object
                kv_cache = kv_caches[local_layer_idx]

                # get remote kvcache
                remote_k = remote_kv[0][local_layer_idx]
                remote_v = remote_kv[1][local_layer_idx]

                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
                                               remote_v, layer, kv_cache,
                                               slot_mapping, start_pos,
                                               end_pos)

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            logger.warning(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.",
                torch.distributed.get_rank())
            hidden_or_intermediate_states = None
        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.",
                torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    @staticmethod
    def tensor_hash(tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor.

        The hash is the first 16 hex digits (64 bits) of the BLAKE2b digest
        of the tensor's raw bytes, returned as an int.
        """
        # tobytes() already produces an independent copy of the data, so no
        # explicit clone of the tensor is needed first.
        tensor_bytes = tensor.detach().cpu().numpy().tobytes()
        hash_hex = hashlib.blake2b(tensor_bytes).hexdigest()
        return int(hash_hex[:16], 16)
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
deleted
100644 → 0
View file @
5ea71ff4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
model_aware_kv_ops_helper
as
kv_helper
)
from
vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer
import
(
SimpleBuffer
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
class
SimpleConnector
(
KVConnectorBase
):
def
__init__
(
self
,
rank
:
int
,
local_rank
:
int
,
config
:
VllmConfig
,
):
self
.
config
=
config
.
kv_transfer_config
self
.
kv_helper
=
kv_helper
(
config
)
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
from
vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe
import
(
PyNcclPipe
)
logger
.
info
(
"Initializing PyNcclConfig under kv_transfer_config %s"
,
self
.
config
)
elif
self
.
config
.
kv_connector
==
"MooncakeConnector"
:
# Check if MOONCAKE_CONFIG_PATH is set
import
os
use_mooncake_distributed_pipe
=
os
.
getenv
(
'MOONCAKE_CONFIG_PATH'
)
is
not
None
if
not
use_mooncake_distributed_pipe
:
raise
ValueError
(
"To use MooncakeConnector, you need to pass the ENV: "
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'."
)
else
:
from
vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe
import
(
# noqa: E501
MooncakePipe
)
logger
.
info
(
"Initializing MooncakeConfig under kv_transfer_config %s"
,
self
.
config
)
self
.
lookup_buffer_size
=
self
.
config
.
kv_buffer_size
self
.
producer_buffer
:
Optional
[
SimpleBuffer
]
=
None
self
.
consumer_buffer
:
Optional
[
SimpleBuffer
]
=
None
self
.
producer_data_pipe
:
Union
[
PyNcclPipe
,
MooncakePipe
]
self
.
consumer_data_pipe
:
Union
[
PyNcclPipe
,
MooncakePipe
]
self
.
producer_signal_pipe
:
Union
[
PyNcclPipe
,
MooncakePipe
]
self
.
consumer_signal_pipe
:
Union
[
PyNcclPipe
,
MooncakePipe
]
# 2 pipes for every rank in the world
port_offset_base
=
2
*
rank
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if
self
.
config
.
is_kv_producer
:
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
self
.
producer_data_pipe
=
PyNcclPipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
port_offset
=
port_offset_base
,
)
self
.
producer_signal_pipe
=
PyNcclPipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
port_offset
=
port_offset_base
+
1
,
device
=
"cpu"
,
)
elif
self
.
config
.
kv_connector
==
"MooncakeConnector"
:
self
.
producer_data_pipe
=
MooncakePipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
)
# We only need to initialize MooncakePipe once
self
.
producer_signal_pipe
=
self
.
producer_data_pipe
self
.
producer_buffer
=
SimpleBuffer
(
self
.
producer_signal_pipe
,
self
.
producer_data_pipe
,
self
.
config
.
kv_buffer_size
)
else
:
# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producer
if
self
.
config
.
kv_connector
==
"PyNcclConnector"
:
self
.
consumer_data_pipe
=
PyNcclPipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
port_offset
=
port_offset_base
,
)
self
.
consumer_signal_pipe
=
PyNcclPipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
port_offset
=
port_offset_base
+
1
,
device
=
"cpu"
,
)
elif
self
.
config
.
kv_connector
==
"MooncakeConnector"
:
self
.
consumer_data_pipe
=
MooncakePipe
(
local_rank
=
local_rank
,
config
=
self
.
config
,
)
self
.
consumer_signal_pipe
=
self
.
consumer_data_pipe
self
.
consumer_buffer
=
SimpleBuffer
(
self
.
consumer_signal_pipe
,
self
.
consumer_data_pipe
,
self
.
config
.
kv_buffer_size
,
)
def
select
(
self
,
input_tokens
:
Optional
[
torch
.
Tensor
],
roi
:
Optional
[
torch
.
Tensor
])
->
list
[
Optional
[
torch
.
Tensor
]]:
assert
self
.
consumer_buffer
is
not
None
,
"Please initialize the "
\
"consumer buffer before calling select."
return
self
.
consumer_buffer
.
drop_select
(
input_tokens
,
roi
)
def
insert
(
self
,
input_tokens
:
torch
.
Tensor
,
roi
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
hidden
:
torch
.
Tensor
)
->
None
:
assert
self
.
producer_buffer
is
not
None
,
"Please initialize the "
\
"producer buffer before calling insert."
self
.
producer_buffer
.
insert
(
input_tokens
,
roi
,
key
,
value
,
hidden
)
def
send_kv_caches_and_hidden_states
(
self
,
model_executable
:
torch
.
nn
.
Module
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
kv_caches
:
list
[
torch
.
Tensor
],
hidden_or_intermediate_states
:
Union
[
torch
.
Tensor
,
IntermediateTensors
],
)
->
None
:
input_tokens_tensor
=
model_input
.
input_tokens
seq_lens
=
model_input
.
attn_metadata
.
seq_lens
slot_mapping_flat
=
model_input
.
attn_metadata
.
slot_mapping
.
flatten
()
num_prefill_tokens
=
model_input
.
attn_metadata
.
num_prefill_tokens
start_layer
=
model_executable
.
model
.
start_layer
end_layer
=
model_executable
.
model
.
end_layer
num_heads
,
head_size
=
self
.
kv_helper
.
get_model_args
(
model_executable
)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# FIXME(Kuntai): This assume that all requests are prefill.
for
idx
,
slen
in
enumerate
(
seq_lens
):
start_pos
=
sum
(
seq_lens
[:
idx
])
end_pos
=
start_pos
+
slen
if
start_pos
>=
num_prefill_tokens
:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger
.
warning
(
"You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent."
)
break
current_tokens
=
input_tokens_tensor
[
start_pos
:
end_pos
]
keys
,
values
=
[],
[]
for
layer_id
in
range
(
start_layer
,
end_layer
):
kv_cache
=
kv_caches
[
layer_id
-
start_layer
]
key_cache
,
value_cache
=
self
.
kv_helper
.
get_kv_from_cache
(
kv_cache
,
num_heads
,
head_size
)
current_slot_mapping
=
slot_mapping_flat
[
start_pos
:
end_pos
]
keys
.
append
(
key_cache
[
current_slot_mapping
].
unsqueeze
(
0
))
values
.
append
(
value_cache
[
current_slot_mapping
].
unsqueeze
(
0
))
keys
=
torch
.
cat
(
keys
,
dim
=
0
)
values
=
torch
.
cat
(
values
,
dim
=
0
)
self
.
insert
(
current_tokens
,
torch
.
ones_like
(
current_tokens
,
dtype
=
bool
),
keys
,
values
,
hidden_or_intermediate_states
[
start_pos
:
end_pos
])
logger
.
debug
(
"[rank%d]: KV send DONE."
,
torch
.
distributed
.
get_rank
())
def recv_kv_caches_and_hidden_states(
    self,
    model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: list[torch.Tensor],
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
           "ModelInputForGPUWithSamplingMetadata"]:
    """Receive KV caches and hidden states for the prefill requests.

    Walks the requests described by ``model_input.attn_metadata.seq_lens``,
    looks each one up through ``self.select`` and writes whatever KV data
    was returned into ``kv_caches`` via ``self.kv_helper.put_kv_to_cache``.
    KV data that was found is written even when the request turns out to
    be only partially covered.

    Args:
        model_executable: Model wrapper; ``model.start_layer``,
            ``model.end_layer`` and ``model.layers`` are read from it.
        model_input: Batch description; ``input_tokens`` and
            ``attn_metadata`` (``seq_lens``, ``num_prefill_tokens``,
            ``slot_mapping``) are read from it.
        kv_caches: KV cache tensors indexed by ``layer - start_layer``.

    Returns:
        ``(hidden_or_intermediate_states, bypass_model_exec, model_input)``.
        ``bypass_model_exec`` is True only when every request received all
        of its tokens and a hidden state; in that case the first element is
        the per-request hidden states concatenated along dim 0. Otherwise
        the first element is ``None`` and the caller has to run the model.
        ``model_input`` is returned as received.
    """
    # When bypass_model_exec is set to False, it means that at least for one
    # request its corresponding KV cache or hidden state is missing.
    # In this case we need to do prefilling to recompute missing KV cache
    # and hidden states.
    bypass_model_exec = True

    input_tokens_tensor = model_input.input_tokens
    seq_lens = model_input.attn_metadata.seq_lens
    num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
    slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
    start_layer = model_executable.model.start_layer
    end_layer = model_executable.model.end_layer

    hidden_or_intermediate_states_for_one_req = []

    # Offset of the next request inside the flattened token tensor. Kept as
    # a running total so the prefix sum is not recomputed per request.
    next_start_pos = 0

    # enumerate different requests
    # FIXME(Kuntai): This impl assumes that all requests are prefill.
    for slen in seq_lens:
        start_pos = next_start_pos
        end_pos = start_pos + slen
        # Advance now: end_pos is overwritten below and some paths
        # `continue` before reaching the end of the loop body.
        next_start_pos = end_pos

        if start_pos >= num_prefill_tokens:
            # This can happen during inflight batching. See:
            # vllm/worker/model_runner.py::_prepare_model_input_tensors:
            # - input_tokens[:num_prefill_tokens] contains prefill tokens.
            # - input_tokens[num_prefill_tokens:] contains decode tokens.
            logger.warning("You should set --enable_chunked_prefill=False "
                           "and --max_num_batched_tokens "
                           "should be equal to --max_seq_len_to_capture")
            bypass_model_exec = False
            assert start_pos == num_prefill_tokens
            break

        current_tokens = input_tokens_tensor[start_pos:end_pos]

        ret = self.select(current_tokens,
                          torch.ones_like(current_tokens, dtype=bool))
        if ret[0] is None:
            # didn't find any match.
            bypass_model_exec = False
            continue

        roi: torch.Tensor = ret[1]
        keys: torch.Tensor = ret[2]
        values: torch.Tensor = ret[3]
        hidden: torch.Tensor = ret[4]

        num_computed_tokens = roi.shape[0]

        # check if both KV cache and the hidden states are received
        # If not, need to redo the forwarding to compute missing states
        if num_computed_tokens != slen or hidden is None:
            bypass_model_exec = False

        # update the end position based on how many tokens are cached.
        end_pos = start_pos + num_computed_tokens

        # put received KV caches into paged memory
        for cur_layer in range(start_layer, end_layer):
            layer_id = cur_layer - start_layer
            kv_cache = kv_caches[layer_id]
            layer = model_executable.model.layers[cur_layer]

            # get remote kvcache
            remote_k, remote_v = keys[layer_id], values[layer_id]

            self.kv_helper.put_kv_to_cache(model_executable, remote_k,
                                           remote_v, layer, kv_cache,
                                           slot_mapping, start_pos, end_pos)

        hidden_or_intermediate_states_for_one_req.append(hidden)

    if not bypass_model_exec:
        # Some of the KV cache is not retrieved
        # Here we will fall back to normal model forwarding
        # But optionally you can adjust model_input so that you only do
        # prefilling on those tokens that are missing KV caches.
        logger.warning(
            "[rank%d]: Failed to receive all KVs and hidden "
            "states, redo model forwarding.", torch.distributed.get_rank())
        hidden_or_intermediate_states = None
    else:
        logger.debug(
            "[rank%d]: Successfully received all KVs and hidden "
            "states, skip model forwarding.", torch.distributed.get_rank())
        hidden_or_intermediate_states = torch.cat(
            hidden_or_intermediate_states_for_one_req, dim=0)

    return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
    """Close the data pipes and, for PyNccl, the signal pipes as well.

    The producer and consumer data pipes are always closed. The separate
    signal pipes are closed only when the configured connector is
    ``"PyNcclConnector"``.
    """
    self.producer_data_pipe.close()
    self.consumer_data_pipe.close()

    connector_name = self.config.kv_connector
    if connector_name != "PyNcclConnector":
        # MooncakePipe reuses data_pipe for signal_pipe, so closing the
        # data pipes above is all that is needed for "MooncakeConnector".
        return

    self.producer_signal_pipe.close()
    self.consumer_signal_pipe.close()
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
f4f4e7ef
...
...
@@ -13,8 +13,8 @@ import torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.
factory
import
(
KVConnector
Factory
)
from
vllm.distributed.kv_transfer.kv_connector.
v1.base
import
(
KVConnector
Base_V1
)
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
...
...
@@ -106,9 +106,8 @@ def get_kv_connector_cache_layout():
vllm_config
=
get_current_vllm_config
()
kv_config
=
vllm_config
.
kv_transfer_config
if
kv_config
is
not
None
:
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
kv_config
)
required_kvcache_layout
=
connector_cls
.
get_required_kvcache_layout
(
vllm_config
)
required_kvcache_layout
=
(
KVConnectorBase_V1
.
get_required_kvcache_layout
(
vllm_config
))
if
required_kvcache_layout
is
not
None
:
return
required_kvcache_layout
logger
.
info_once
(
"Connectors do not specify a "
\
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
f4f4e7ef
...
...
@@ -52,7 +52,7 @@ class MultiConnector(KVConnectorBase_V1):
temp_config
.
kv_transfer_config
=
KVTransferConfig
(
**
ktc
,
engine_id
=
engine_id
)
self
.
_connectors
.
append
(
KVConnectorFactory
.
create_connector
_v1
(
temp_config
,
role
))
KVConnectorFactory
.
create_connector
(
temp_config
,
role
))
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
...
...
@@ -223,9 +223,9 @@ class MultiConnector(KVConnectorBase_V1):
for
ktc
in
ktcs
:
kv_transfer_config
=
KVTransferConfig
(
**
ktc
)
temp_vllm_config
.
kv_transfer_config
=
kv_transfer_config
required_kvcache_layout
=
KVConnectorFactory
.
get_connector_class
(
kv_transfer_config
)
.
get_required_kvcache_layout
(
temp_vllm_config
)
required_kvcache_layout
=
(
KVConnectorBase_V1
.
get_required_kvcache_layout
(
temp_vllm_config
)
)
if
required_kvcache_layout
is
not
None
:
layouts
.
add
(
required_kvcache_layout
)
...
...
vllm/distributed/kv_transfer/kv_connector_agent.py
deleted
100644 → 0
View file @
5ea71ff4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A centralized entrypoint to perform distributed KV cache transfer.
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states`
"""
from
typing
import
TYPE_CHECKING
,
Union
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.config
import
VllmConfig
import
torch
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
class KVTransferAgent:
    """
    Entry point for distributed KV transfer that forwards to a connector.

    Target use cases:
        1. Disaggregated prefill
        2. Remote KV cache storage
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """Validate the KV transfer config and build the connector.

        Raises:
            ValueError: if ``config.kv_transfer_config`` is ``None``.
        """
        self.config = config

        transfer_config = config.kv_transfer_config
        if transfer_config is None:
            raise ValueError("KVTransferConfig is not set in the VllmConfig,"
                             " cannot initialize KVConnector.")

        assert transfer_config.is_kv_transfer_instance, (
            "KVTransferAgent should only be used when kv_connector is set.")

        # The connector is created through the factory's v0 constructor.
        self.connector = KVConnectorFactory.create_connector_v0(
            rank, local_rank, config)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Forward the send request to the underlying connector."""
        connector = self.connector
        connector.send_kv_caches_and_hidden_states(
            model_executable,
            model_input,
            kv_caches,
            hidden_or_intermediate_states,
        )

    def close(self) -> None:
        """Close the underlying connector."""
        self.connector.close()

    def recv_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Forward the receive request to the connector and return its result."""
        received = self.connector.recv_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches)
        return received
vllm/distributed/kv_transfer/kv_transfer_state.py
View file @
f4f4e7ef
...
...
@@ -8,7 +8,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
from
vllm.distributed.parallel_state
import
get_world_group
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
...
...
@@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if
(
vllm_config
.
kv_transfer_config
.
is_kv_transfer_instance
and
_KV_CONNECTOR_AGENT
is
None
):
if
envs
.
VLLM_USE_V1
:
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector
_v1
(
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector
(
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
)
else
:
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector_v0
(
rank
=
get_world_group
().
rank
,
local_rank
=
get_world_group
().
local_rank
,
config
=
vllm_config
,
)
raise
ValueError
(
"V0 is no longer supported"
)
vllm/v1/core/sched/scheduler.py
View file @
f4f4e7ef
...
...
@@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector
_v1
(
self
.
connector
=
KVConnectorFactory
.
create_connector
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
f4f4e7ef
...
...
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Optional
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.
v1
import
KVConnectorBase
_V1
from
vllm.distributed.kv_transfer.kv_connector.
base
import
KVConnectorBase
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
...
...
@@ -31,7 +31,7 @@ class KVConnectorModelRunnerMixin:
# Update KVConnector with the KVConnector metadata forward().
if
has_kv_transfer_group
():
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase
_V1
)
assert
isinstance
(
kv_connector
,
KVConnectorBase
)
assert
scheduler_output
.
kv_connector_metadata
is
not
None
kv_connector
.
bind_connector_metadata
(
scheduler_output
.
kv_connector_metadata
)
...
...
@@ -93,7 +93,7 @@ class KVConnectorModelRunnerMixin:
# Update KVConnector with the KVConnector metadata forward().
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase
_V1
)
assert
isinstance
(
kv_connector
,
KVConnectorBase
)
assert
scheduler_output
.
kv_connector_metadata
is
not
None
kv_connector
.
bind_connector_metadata
(
scheduler_output
.
kv_connector_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment