Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f42a488
Unverified
Commit
2f42a488
authored
Feb 25, 2025
by
Jiayi Yao
Committed by
GitHub
Feb 25, 2025
Browse files
[Feature] Support KV cache offloading and disagg prefill with LMCache connector. (#12953)
parent
3173c3b3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
310 additions
and
2 deletions
+310
-2
examples/offline_inference/cpu_offload_lmcache.py
examples/offline_inference/cpu_offload_lmcache.py
+65
-0
examples/offline_inference/disaggregated_prefill_lmcache.py
examples/offline_inference/disaggregated_prefill_lmcache.py
+130
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+5
-0
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
...distributed/kv_transfer/kv_connector/lmcache_connector.py
+108
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-2
No files found.
examples/offline_inference/cpu_offload_lmcache.py
0 → 100644
View file @
2f42a488
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of cpu offloading
with LMCache.
Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import
os
import
time
from
lmcache.experimental.cache_engine
import
LMCacheEngineBuilder
from
lmcache.integration.vllm.utils
import
ENGINE_NAME
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
# LMCache-related environment variables
# Use experimental features in LMCache
os
.
environ
[
"LMCACHE_USE_EXPERIMENTAL"
]
=
"True"
# LMCache is set to use 256 tokens per chunk
os
.
environ
[
"LMCACHE_CHUNK_SIZE"
]
=
"256"
# Enable local CPU backend in LMCache
os
.
environ
[
"LMCACHE_LOCAL_CPU"
]
=
"True"
# Set local CPU memory limit to 5.0 GB
os
.
environ
[
"LMCACHE_MAX_LOCAL_CPU_SIZE"
]
=
"5.0"
# This example script runs two requests with a shared prefix.
shared_prompt
=
"Hello, how are you?"
*
1000
first_prompt
=
[
shared_prompt
+
"Hello, my name is"
,
]
second_prompt
=
[
shared_prompt
+
"Tell me a very long story"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
,
top_p
=
0.95
,
max_tokens
=
10
)
ktc
=
KVTransferConfig
.
from_cli
(
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}'
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm
=
LLM
(
model
=
"mistralai/Mistral-7B-Instruct-v0.2"
,
kv_transfer_config
=
ktc
,
max_model_len
=
8000
,
enable_chunked_prefill
=
False
,
gpu_memory_utilization
=
0.8
)
outputs
=
llm
.
generate
(
first_prompt
,
sampling_params
)
for
output
in
outputs
:
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
print
(
"First request done."
)
time
.
sleep
(
1
)
outputs
=
llm
.
generate
(
second_prompt
,
sampling_params
)
for
output
in
outputs
:
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
print
(
"Second request done."
)
# Clean up lmcache backend
LMCacheEngineBuilder
.
destroy
(
ENGINE_NAME
)
examples/offline_inference/disaggregated_prefill_lmcache.py
0 → 100644
View file @
2f42a488
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of disaggregated prefilling
with LMCache.
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and launch an additional LMCache server.
KV cache is transferred in the following manner:
VLLM prefill node -> LMCache server -> VLLM decode node.
Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import
os
import
subprocess
import
time
from
multiprocessing
import
Event
,
Process
from
lmcache.experimental.cache_engine
import
LMCacheEngineBuilder
from
lmcache.integration.vllm.utils
import
ENGINE_NAME
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
# LMCache-related environment variables
# The port to start LMCache server
port
=
8100
# Use experimental features in LMCache
os
.
environ
[
"LMCACHE_USE_EXPERIMENTAL"
]
=
"True"
# LMCache is set to use 256 tokens per chunk
os
.
environ
[
"LMCACHE_CHUNK_SIZE"
]
=
"256"
# Disable local CPU backend in LMCache
os
.
environ
[
"LMCACHE_LOCAL_CPU"
]
=
"False"
# Set local CPU memory buffer limit to 5.0 GB
os
.
environ
[
"LMCACHE_MAX_LOCAL_CPU_SIZE"
]
=
"5.0"
# Set the remote URL for LMCache server
os
.
environ
[
"LMCACHE_REMOTE_URL"
]
=
f
"lm://localhost:
{
port
}
"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os
.
environ
[
"LMCACHE_REMOTE_SERDE"
]
=
"naive"
def
run_prefill
(
prefill_done
,
prompts
):
# We use GPU 0 for prefill node.
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"0"
sampling_params
=
SamplingParams
(
temperature
=
0
,
top_p
=
0.95
,
max_tokens
=
1
)
ktc
=
KVTransferConfig
.
from_cli
(
'{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm
=
LLM
(
model
=
"mistralai/Mistral-7B-Instruct-v0.2"
,
kv_transfer_config
=
ktc
,
max_model_len
=
8000
,
gpu_memory_utilization
=
0.8
,
enforce_eager
=
True
)
#llm.generate(prompts, sampling_params)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
in
outputs
:
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
print
(
"Prefill node is finished."
)
prefill_done
.
set
()
# Clean up lmcache backend
LMCacheEngineBuilder
.
destroy
(
ENGINE_NAME
)
def
run_decode
(
prefill_done
,
prompts
,
timeout
=
1
):
# We use GPU 1 for decode node.
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"1"
sampling_params
=
SamplingParams
(
temperature
=
0
,
top_p
=
0.95
,
max_tokens
=
10
)
ktc
=
KVTransferConfig
.
from_cli
(
'{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm
=
LLM
(
model
=
"mistralai/Mistral-7B-Instruct-v0.2"
,
kv_transfer_config
=
ktc
,
max_model_len
=
8000
,
gpu_memory_utilization
=
0.8
,
enforce_eager
=
True
)
print
(
"Waiting for prefill node to finish..."
)
prefill_done
.
wait
()
time
.
sleep
(
timeout
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
in
outputs
:
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
# Clean up lmcache backend
LMCacheEngineBuilder
.
destroy
(
ENGINE_NAME
)
def
run_lmcache_server
(
port
):
server_proc
=
subprocess
.
Popen
([
"python"
,
"-m"
,
"lmcache.experimental.server"
,
"localhost"
,
str
(
port
)
])
return
server_proc
if
__name__
==
"__main__"
:
prompts
=
[
"Hello, how are you?"
*
1000
,
]
prefill_done
=
Event
()
prefill_process
=
Process
(
target
=
run_prefill
,
args
=
(
prefill_done
,
prompts
))
decode_process
=
Process
(
target
=
run_decode
,
args
=
(
prefill_done
,
prompts
))
lmcache_server_process
=
run_lmcache_server
(
port
)
# Start prefill node
prefill_process
.
start
()
# Start decode node
decode_process
.
start
()
# Clean up the processes
decode_process
.
join
()
prefill_process
.
terminate
()
lmcache_server_process
.
terminate
()
lmcache_server_process
.
wait
()
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
2f42a488
...
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
...
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector"
,
"MooncakeConnector"
,
"vllm.distributed.kv_transfer.kv_connector.simple_connector"
,
"vllm.distributed.kv_transfer.kv_connector.simple_connector"
,
"SimpleConnector"
)
"SimpleConnector"
)
KVConnectorFactory
.
register_connector
(
"LMCacheConnector"
,
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector"
,
"LMCacheConnector"
)
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
0 → 100644
View file @
2f42a488
# SPDX-License-Identifier: Apache-2.0
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches.
"""
from
typing
import
TYPE_CHECKING
,
List
,
Tuple
,
Union
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
class
LMCacheConnector
(
KVConnectorBase
):
def
__init__
(
self
,
rank
:
int
,
local_rank
:
int
,
config
:
VllmConfig
,
):
self
.
transfer_config
=
config
.
kv_transfer_config
self
.
vllm_config
=
config
from
lmcache.experimental.cache_engine
import
LMCacheEngineBuilder
from
lmcache.integration.vllm.utils
import
ENGINE_NAME
from
lmcache.integration.vllm.vllm_adapter
import
(
RetrieveStatus
,
StoreStatus
,
init_lmcache_engine
,
lmcache_retrieve_kv
,
lmcache_should_store
,
lmcache_store_kv
)
logger
.
info
(
"Initializing LMCacheConfig under kv_transfer_config %s"
,
self
.
transfer_config
)
# TODO (Jiayi): Find model_config, parallel_config, and cache_config
self
.
engine
=
init_lmcache_engine
(
config
.
model_config
,
config
.
parallel_config
,
config
.
cache_config
)
self
.
lmcache_engine_name
=
ENGINE_NAME
self
.
lmcache_engine_builder
=
LMCacheEngineBuilder
self
.
model_config
=
config
.
model_config
self
.
parallel_config
=
config
.
parallel_config
self
.
cache_config
=
config
.
cache_config
self
.
lmcache_retrieve_kv
=
lmcache_retrieve_kv
self
.
lmcache_store_kv
=
lmcache_store_kv
self
.
lmcache_should_store
=
lmcache_should_store
self
.
store_status
=
StoreStatus
self
.
retrieve_status
=
RetrieveStatus
def
recv_kv_caches_and_hidden_states
(
self
,
model_executable
:
torch
.
nn
.
Module
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
kv_caches
:
List
[
torch
.
Tensor
]
)
->
Tuple
[
Union
[
torch
.
Tensor
,
IntermediateTensors
],
bool
,
"ModelInputForGPUWithSamplingMetadata"
]:
hidden_or_intermediate_states
=
None
# TODO (Jiayi): Need to support chunked prefill
retrieve_status
=
self
.
retrieve_status
.
PREFILL
model_input
,
bypass_model_exec
=
self
.
lmcache_retrieve_kv
(
model_executable
,
model_input
,
self
.
cache_config
,
kv_caches
,
retrieve_status
)
return
hidden_or_intermediate_states
,
bypass_model_exec
,
model_input
def
send_kv_caches_and_hidden_states
(
self
,
model_executable
:
torch
.
nn
.
Module
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
kv_caches
:
List
[
torch
.
Tensor
],
hidden_or_intermediate_states
:
Union
[
torch
.
Tensor
,
IntermediateTensors
],
)
->
None
:
num_reqs
=
0
seq_group_list
=
model_input
.
sampling_metadata
.
seq_groups
assert
seq_group_list
is
not
None
for
seq_group
in
seq_group_list
:
seq_ids
=
seq_group
.
seq_ids
for
seq_id
in
seq_ids
:
num_reqs
+=
1
# TODO (Jiayi): Only normal prefill is supported for now
store_status
=
self
.
lmcache_should_store
(
model_input
)
self
.
lmcache_store_kv
(
self
.
model_config
,
self
.
parallel_config
,
self
.
cache_config
,
model_executable
,
model_input
,
kv_caches
,
store_status
,
)
def
close
(
self
):
self
.
lmcache_engine_builder
.
destroy
(
self
.
lmcache_engine_name
)
vllm/distributed/parallel_state.py
View file @
2f42a488
...
@@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
...
@@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
return
return
if
all
([
if
all
([
vllm_config
.
kv_transfer_config
.
need_kv_parallel_group
,
_KV_TRANSFER
vllm_config
.
kv_transfer_config
.
is_kv_transfer_instance
,
is
None
_KV_TRANSFER
is
None
]):
]):
_KV_TRANSFER
=
kv_transfer
.
KVTransferAgent
(
_KV_TRANSFER
=
kv_transfer
.
KVTransferAgent
(
rank
=
get_world_group
().
rank
,
rank
=
get_world_group
().
rank
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment