OpenDAS / vllm_cscc
Commit 5e83a727, authored Apr 25, 2025 by Yihua Cheng, committed by GitHub Apr 26, 2025
[v1] [P/D] Adding LMCache KV connector for v1 (#16625)
Parent: 68af5f6c
Showing 12 changed files with 793 additions and 0 deletions
examples/lmcache/README.md (+56 -0)
examples/lmcache/cpu_offload_lmcache_v0.py (+0 -0)
examples/lmcache/cpu_offload_lmcache_v1.py (+57 -0)
examples/lmcache/disagg_prefill_lmcache_v0.py (+0 -0)
examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml (+13 -0)
examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml (+13 -0)
examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh (+136 -0)
examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py (+193 -0)
examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh (+59 -0)
examples/lmcache/kv_cache_sharing_lmcache_v1.py (+130 -0)
vllm/distributed/kv_transfer/kv_connector/factory.py (+5 -0)
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py (+131 -0)
examples/lmcache/README.md
0 → 100644
View file @
5e83a727
# LMCache Examples

This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.

## 1. Disaggregated Prefill in vLLM v1

This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.

### Prerequisites

- Install [LMCache](https://github.com/LMCache/LMCache)
- Install [NIXL](https://github.com/ai-dynamo/nixl)
- At least 2 GPUs
- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct

### Usage

Run `cd disagg_prefill_lmcache_v1` to enter the `disagg_prefill_lmcache_v1` folder, then launch disaggregated prefill and benchmark its performance with:

```bash
bash disagg_example_nixl.sh
```
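
Before benchmarking, `disagg_example_nixl.sh` polls each server once per second until it responds or a 1200-second timeout expires. The same readiness-polling idea can be sketched in Python as follows. This is an illustrative sketch only: the function `wait_until_ready` and its parameters are hypothetical names, not part of vLLM or LMCache, and the readiness check is injected so that the loop can be exercised without a running server.

```python
import time
from typing import Callable


def wait_until_ready(is_ready: Callable[[], bool],
                     timeout_seconds: float = 1200.0,
                     poll_interval: float = 1.0,
                     clock: Callable[[], float] = time.monotonic,
                     sleep: Callable[[float], None] = time.sleep) -> bool:
    """Poll is_ready() until it returns True or the timeout elapses.

    Returns True when the check succeeded and False on timeout.
    (Hypothetical helper; it mirrors the shape of the script's polling loop.)
    """
    start = clock()
    while True:
        if is_ready():
            return True
        if clock() - start >= timeout_seconds:
            return False
        sleep(poll_interval)


# Usage: a fake check that succeeds on the third attempt.
attempts = iter([False, False, True])
ready = wait_until_ready(lambda: next(attempts), sleep=lambda _: None)
print(ready)
```

In a real setup, `is_ready` would issue an HTTP request against the server port, as the shell script does with `curl`.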
### Components

#### Server Scripts

- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches a single vLLM server in either the prefiller or the decoder role
- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder
- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script: launches the prefiller, the decoder and the proxy server, then runs the benchmark
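
The proxy's coordination is a two-step flow: it first sends a copy of the incoming request to the prefiller with generation capped at one token, then streams the original request from the decoder. A minimal sketch of the request-rewriting step, modeled on `send_request_to_service` in `disagg_proxy_server.py`, is shown below. The helper name `build_prefill_request` is hypothetical and exists only for illustration.

```python
def build_prefill_request(req_data: dict) -> dict:
    """Return a copy of the request that stops the prefiller after one token."""
    prefill_req = req_data.copy()  # shallow copy; the original goes to the decoder
    prefill_req["max_tokens"] = 1
    if "max_completion_tokens" in prefill_req:
        prefill_req["max_completion_tokens"] = 1
    return prefill_req


original = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "Hello",
    "max_tokens": 200,
}
prefill = build_prefill_request(original)
print(prefill["max_tokens"], original["max_tokens"])
```

Because the prefill copy is separate, the decoder still receives the caller's original token limit.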
#### Configuration

- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server
- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server
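
In this commit the two configuration files are identical except for `nixl_role` (`sender` for the prefiller, `receiver` for the decoder). The sketch below mirrors the YAML values as plain Python dictionaries to make that difference explicit. It does not parse the YAML files, and the helper `differing_keys` is a hypothetical name used only for illustration.

```python
# Values copied from the two YAML files; YAML NULL is represented as None.
COMMON = {
    "local_cpu": False,
    "max_local_cpu_size": 0,
    "max_local_disk_size": 0,
    "remote_serde": None,
    "enable_nixl": True,
    "nixl_peer_host": "localhost",
    "nixl_peer_port": 55555,
    "nixl_buffer_size": 1073741824,  # 1024**3 bytes
    "nixl_buffer_device": "cuda",
    "nixl_enable_gc": True,
}
prefiller_config = {**COMMON, "nixl_role": "sender"}
decoder_config = {**COMMON, "nixl_role": "receiver"}


def differing_keys(a: dict, b: dict) -> list:
    """Return the sorted keys whose values differ between two configs."""
    return sorted(k for k in a.keys() | b.keys() if a.get(k) != b.get(k))


print(differing_keys(prefiller_config, decoder_config))
```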
#### Log Files

The main script generates several log files:

- `prefiller.log` - Logs from the prefill server
- `decoder.log` - Logs from the decode server
- `proxy.log` - Logs from the proxy server
## 2. CPU Offload Examples

- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
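
The v1 example sets `LMCACHE_CHUNK_SIZE` to 256 tokens, and its comments mention a log line reporting 6006 stored tokens. As a rough illustration of what that chunk size means, the sketch below splits a token count into full chunks and a remainder. The helper `split_into_chunks` is hypothetical. How LMCache actually treats the trailing partial chunk is not shown here and may depend on settings such as `discard_partial_chunks`.

```python
def split_into_chunks(num_tokens: int, chunk_size: int = 256) -> tuple[int, int]:
    """Return (number of full chunks, leftover tokens) for a token count."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return divmod(num_tokens, chunk_size)


full_chunks, leftover = split_into_chunks(6006)
print(full_chunks, leftover)  # 23 full chunks, 118 leftover tokens
```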
## 3. KV Cache Sharing

The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
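
In that example, the retrieving instance must not start generating until the storing instance has written its KV cache, and the script enforces this ordering with a `multiprocessing.Event`. The sketch below shows the same ordering pattern with threads standing in for the two vLLM processes. It is a simplified illustration under that assumption: no model, GPU, or LMCache server is involved, and the function bodies are placeholders.

```python
import threading


def run_store(store_done: threading.Event, log: list) -> None:
    """Placeholder for the storing instance: do the work, then signal."""
    log.append("store: KV cache written")
    store_done.set()


def run_retrieve(store_done: threading.Event, log: list) -> None:
    """Placeholder for the retrieving instance: wait for the signal first."""
    store_done.wait()
    log.append("retrieve: KV cache reused")


def demo() -> list:
    store_done = threading.Event()
    log: list = []
    retriever = threading.Thread(target=run_retrieve, args=(store_done, log))
    storer = threading.Thread(target=run_store, args=(store_done, log))
    retriever.start()  # started first on purpose; it still has to wait
    storer.start()
    storer.join()
    retriever.join()
    return log


print(demo())
```

Even though the retriever starts first, it blocks on the event, so the store step always appears first in the log.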
## 4. Disaggregated Prefill in vLLM v0

The `disagg_prefill_lmcache_v0.py` example shows how to run disaggregated prefill in vLLM v0.
examples/offline_inference/cpu_offload_lmcache.py → examples/lmcache/cpu_offload_lmcache_v0.py
View file @
5e83a727
File moved
examples/lmcache/cpu_offload_lmcache_v1.py
0 → 100644
View file @
5e83a727
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of CPU offloading
with LMCache in vLLM v1.

Note that lmcache needs to be installed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"

# This example script runs two requests with a shared prefix.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
    shared_prompt + "Hello, my name is",
]
second_prompt = [
    shared_prompt + "Tell me a very long story",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8, chosen for a GPU with roughly
# 40 GB of memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
          kv_transfer_config=ktc,
          max_model_len=8000,
          gpu_memory_utilization=0.8)

# Should be able to see logs like the following:
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
# This indicates that the KV cache has been stored in LMCache.
outputs = llm.generate(first_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

# Run the second request, which shares the long prefix with the first one,
# so the KV cache stored above is a candidate for reuse.
outputs = llm.generate(second_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
examples/offline_inference/disaggregated_prefill_lmcache.py → examples/lmcache/disagg_prefill_lmcache_v0.py
View file @
5e83a727
File moved
examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml
0 → 100644
View file @
5e83a727
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL

enable_nixl: True
nixl_role: "receiver"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml
0 → 100644
View file @
5e83a727
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL

enable_nixl: True
nixl_role: "sender"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh
0 → 100644
View file @
5e83a727
#!/bin/bash

echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."

PIDS=()

# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"

check_hf_token() {
    if [ -z "$HF_TOKEN" ]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        exit 1
    fi
    if [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}

check_num_gpus() {
    # Use nvidia-smi to check that at least 2 GPUs are available
    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
    if [ "$num_gpus" -lt 2 ]; then
        echo "You need at least 2 GPUs to run disaggregated prefill."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}

ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    python -c "import $1" > /dev/null 2>&1
    if [ $? -ne 0 ]; then
        if [ "$1" == "nixl" ]; then
            echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
        else
            echo "$1 is not installed. Please install it via pip install $1."
        fi
        exit 1
    else
        echo "$1 is installed."
    fi
}

cleanup() {
    echo "Stopping everything…"
    trap - INT TERM        # prevent re-entrancy
    kill -- -$$            # negative PID == "this whole process group"
    wait                   # reap children so we don't leave zombies
    exit 0
}

wait_for_server() {
    local port=$1
    local timeout_seconds=1200
    local start_time=$(date +%s)

    echo "Waiting for server on port $port..."

    while true; do
        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
            return 0
        fi

        local now=$(date +%s)
        if (( now - start_time >= timeout_seconds )); then
            echo "Timeout waiting for server"
            return 1
        fi

        sleep 1
    done
}

main() {
    check_hf_token
    check_num_gpus
    ensure_python_library_installed lmcache
    ensure_python_library_installed nixl
    ensure_python_library_installed pandas
    ensure_python_library_installed datasets
    ensure_python_library_installed vllm

    trap cleanup INT
    trap cleanup USR1
    trap cleanup TERM

    echo "Launching prefiller, decoder and proxy..."
    echo "Please check prefiller.log, decoder.log and proxy.log for logs."

    bash disagg_vllm_launcher.sh prefiller \
        > >(tee prefiller.log) 2>&1 &
    prefiller_pid=$!
    PIDS+=($prefiller_pid)

    bash disagg_vllm_launcher.sh decoder \
        > >(tee decoder.log) 2>&1 &
    decoder_pid=$!
    PIDS+=($decoder_pid)

    python3 disagg_proxy_server.py \
        --host localhost \
        --port 9000 \
        --prefiller-host localhost \
        --prefiller-port 8100 \
        --decoder-host localhost \
        --decoder-port 8200 \
        > >(tee proxy.log) 2>&1 &
    proxy_pid=$!
    PIDS+=($proxy_pid)

    wait_for_server 8100
    wait_for_server 8200
    wait_for_server 9000

    echo "All servers are up. Starting benchmark..."

    # begin benchmark
    cd ../../../benchmarks/
    python benchmark_serving.py --port 9000 --seed $(date +%s) \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name random --random-input-len 7500 --random-output-len 200 \
        --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log

    echo "Benchmarking done. Cleaning up..."

    cleanup
}

main
\ No newline at end of file
examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py
0 → 100644
View file @
5e83a727
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import time
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize clients
    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'

    app.state.prefill_client = httpx.AsyncClient(timeout=None,
                                                 base_url=prefiller_base_url)
    app.state.decode_client = httpx.AsyncClient(timeout=None,
                                                base_url=decoder_base_url)

    yield

    # Shutdown: Close clients
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


class StatsCalculator:

    def __init__(self):
        self._stats = []
        self._last_log_time = time.time()

    def add(self, value):
        self._stats.append(value)
        # Log at most once every 5 seconds
        if time.time() - self._last_log_time > 5:
            self._log_stats()
            self._last_log_time = time.time()

    def _log_stats(self):
        # Print average, median, and 99th percentile
        np_arr = np.array(self._stats)
        output_str = f"\nNum requests: {len(self._stats)}" + \
            "\nPrefill node TTFT stats:" + \
            f"\n - Average (ms): {np.mean(np_arr)}" + \
            f"\n - Median (ms): {np.median(np_arr)}" + \
            f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
        print("===============================", output_str,
              "===============================")


stats_calculator = StatsCalculator()
counter = 0


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-host", type=str, default="localhost")
    parser.add_argument("--prefiller-port", type=int, default=8100)
    parser.add_argument("--decoder-host", type=str, default="localhost")
    parser.add_argument("--decoder-port", type=int, default=8200)
    args = parser.parse_args()
    return args


# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Send a request to a service using a persistent client.
    """
    # Cap generation at one token so the prefiller only performs prefill
    req_data = req_data.copy()
    req_data['max_tokens'] = 1
    if 'max_completion_tokens' in req_data:
        req_data['max_completion_tokens'] = 1

    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    response = await client.post(endpoint, json=req_data, headers=headers)
    response.raise_for_status()
    return response


async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Asynchronously stream the response from a service using a persistent client.
    """
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


@app.post("/v1/completions")
async def handle_completions(request: Request):
    global counter, stats_calculator
    counter += 1

    st = time.time()
    try:
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(app.state.prefill_client, "/completions",
                                      req_data)

        et = time.time()
        # time.time() returns seconds; convert to milliseconds to match
        # the units printed by StatsCalculator
        stats_calculator.add((et - st) * 1000)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(app.state.decode_client,
                                                       "/completions",
                                                       req_data):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    global counter, stats_calculator
    counter += 1

    st = time.time()
    try:
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(app.state.prefill_client,
                                      "/chat/completions", req_data)

        et = time.time()
        # time.time() returns seconds; convert to milliseconds to match
        # the units printed by StatsCalculator
        stats_calculator.add((et - st) * 1000)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(app.state.decode_client,
                                                       "/chat/completions",
                                                       req_data):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - chat completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


if __name__ == '__main__':
    global global_args
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh
0 → 100644
View file @
5e83a727
#!/bin/bash

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <prefiller | decoder> [model]"
    exit 1
fi

if [[ $# -eq 1 ]]; then
    echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
    MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
    echo "Using model: $2"
    MODEL=$2
fi

if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 8100
    prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$prefill_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=0 \
        vllm serve $MODEL \
        --port 8100 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'

elif [[ $1 == "decoder" ]]; then
    # Decoder listens on port 8200
    decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$decode_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=1 \
        vllm serve $MODEL \
        --port 8200 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'

else
    echo "Invalid role: $1"
    echo "Should be either prefiller or decoder"
    exit 1
fi
examples/lmcache/kv_cache_sharing_lmcache_v1.py
0 → 100644
View file @
5e83a727
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).

Note that lmcache needs to be installed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# The port to start LMCache server
port = 8100
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"

prompts = [
    "Hello, how are you?" * 1000,
]


def run_store(store_done, prompts):
    # We use GPU 0 for KV cache store process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
    # Set GPU memory utilization to 0.8, chosen for a GPU with roughly
    # 40 GB of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print("KV cache store is finished.")
    store_done.set()

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_retrieve(store_done, prompts, timeout=1):
    # We use GPU 1 for KV cache retrieve process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
    # Set GPU memory utilization to 0.8, chosen for a GPU with roughly
    # 40 GB of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    print("Waiting for KV cache store to finish...")
    store_done.wait()
    time.sleep(timeout)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_lmcache_server(port):
    server_proc = subprocess.Popen([
        "python", "-m", "lmcache.experimental.server", "localhost",
        str(port)
    ])
    return server_proc


def main():
    store_done = Event()
    store_process = Process(target=run_store, args=(store_done, prompts))
    retrieve_process = Process(target=run_retrieve,
                               args=(store_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    # Start KV cache store process
    store_process.start()

    # Start KV cache retrieve process
    retrieve_process.start()

    # Clean up the processes
    store_process.join()
    # Wait for the retrieve process to finish generating; terminating it
    # here could kill it before it has reused the stored KV cache.
    retrieve_process.join()
    lmcache_server_process.terminate()
    lmcache_server_process.wait()


if __name__ == "__main__":
    main()
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
5e83a727
...
...
@@ -100,3 +100,8 @@ KVConnectorFactory.register_connector(
    "SharedStorageConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
    "SharedStorageConnector")

KVConnectorFactory.register_connector(
    "LMCacheConnectorV1",
    "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
    "LMCacheConnectorV1")
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
0 → 100644
View file @
5e83a727
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING

import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.request import Request

logger = init_logger(__name__)


class LMCacheConnectorV1(KVConnectorBase_V1):

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        self._lmcache_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to
        ensure async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        self._lmcache_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer
        to enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                           **kwargs)

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving
        is done.
        """
        self._lmcache_engine.wait_for_save()

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> int:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        return self._lmcache_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
        """
        self._lmcache_engine.update_state_after_alloc(request,
                                                      num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._lmcache_engine.build_connector_meta(scheduler_output)