Unverified Commit 0590ec3f authored by Kuntai Du's avatar Kuntai Du Committed by GitHub
Browse files

[Core] Implement disagg prefill by StatelessProcessGroup (#10502)



This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: default avatarKuntaiDu <kuntai@uchicago.edu>
Co-authored-by: default avatarApostaC <yihua98@uchicago.edu>
Co-authored-by: default avatarYaoJiayi <120040070@link.cuhk.edu.cn>
parent c11f1721
...@@ -430,6 +430,9 @@ steps: ...@@ -430,6 +430,9 @@ steps:
- vllm/model_executor/models/ - vllm/model_executor/models/
- tests/distributed/ - tests/distributed/
- vllm/compilation - vllm/compilation
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/model_runner.py
commands: commands:
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
...@@ -443,6 +446,7 @@ steps: ...@@ -443,6 +446,7 @@ steps:
- pip install -e ./plugins/vllm_add_dummy_model - pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py - pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
- label: Multi-step Tests (4 GPUs) # 36min - label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
......
#!/bin/bash
# benchmark the overhead of disaggregated prefill.
# methodology:
# - send all request to prefill vLLM instance. It will buffer KV cache.
# - then send all request to decode instance.
# - The TTFT of decode instance is the overhead.
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pkill -f pt_main_thread
sleep 10
# remove vllm config file
rm -rf ~/.config/vllm
# Print the GPU memory usage
# so that we know if all GPU processes are killed.
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
# The memory usage should be 0 MB.
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
}
wait_for_server() {
# wait for vllm server to start
# return 1 if vllm server crashes
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
benchmark() {
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# compare chunked prefill with disaggregated prefill
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=10
qps=$1
prefix_len=50
input_len=2048
output_len=$2
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
# let the prefill instance finish prefill
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8100 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_2xtp4.json \
--request-rate "inf"
# send the request to decode.
# The TTFT of this command will be the overhead of disagg prefill impl.
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8200 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_2xtp4.json \
--request-rate "$qps"
kill_gpu_processes
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
pip install quart httpx
cd "$(dirname "$0")"
cd ..
# create sonnet-4x.txt
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_qps=1
default_output_len=1
benchmark $default_qps $default_output_len
}
main "$@"
#!/bin/bash
# Requirement: 8x H100 GPUs.
# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
# Resource: 8x H100
# Approaches:
# 1. Chunked prefill: 1 vllm instance with tp=8
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
# Prefilling instance: max_output_token=1
# Decoding instance: force the input tokens be the same across requests to bypass prefilling
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pgrep pt_main_thread | xargs -r kill -9
pgrep python3 | xargs -r kill -9
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
sleep 1
}
wait_for_server() {
# wait for vllm server to start
# return 1 if vllm server crashes
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
launch_chunked_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8100 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8200 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
wait_for_server 8100
wait_for_server 8200
python3 round_robin_proxy.py &
sleep 1
}
launch_disagg_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
python3 disagg_prefill_proxy_server.py &
sleep 1
}
benchmark() {
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=100
qps=$1
prefix_len=50
input_len=1024
output_len=$2
tag=$3
python3 ../benchmark_serving.py \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8000 \
--save-result \
--result-dir $results_folder \
--result-filename "$tag"-qps-"$qps".json \
--request-rate "$qps"
sleep 2
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
pip install quart httpx matplotlib aiohttp
cd "$(dirname "$0")"
cd ..
# create sonnet-4x.txt so that we can sample 2048 tokens for input
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_output_len=6
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
launch_chunked_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len chunked_prefill
done
kill_gpu_processes
launch_disagg_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len disagg_prefill
done
kill_gpu_processes
python3 visualize_benchmark_results.py
}
main "$@"
import os
import aiohttp
from quart import Quart, make_response, request
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked':
if True:
async for chunk_bytes in response.content.iter_chunked(
1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route('/v1/completions', methods=['POST'])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1
# finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions',
prefill_request):
continue
# return decode
generator = forward_request('http://localhost:8200/v1/completions',
original_request_data)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__':
app.run(port=8000)
import asyncio
import itertools
import aiohttp
from aiohttp import web
class RoundRobinProxy:
def __init__(self, target_ports):
self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports)
async def handle_request(self, request):
target_port = next(self.port_cycle)
target_url = f"http://localhost:{target_port}{request.path_qs}"
async with aiohttp.ClientSession() as session:
try:
# Forward the request
async with session.request(
method=request.method,
url=target_url,
headers=request.headers,
data=request.content,
) as response:
# Start sending the response
resp = web.StreamResponse(status=response.status,
headers=response.headers)
await resp.prepare(request)
# Stream the response content
async for chunk in response.content.iter_any():
await resp.write(chunk)
await resp.write_eof()
return resp
except Exception as e:
return web.Response(text=f"Error: {str(e)}", status=500)
async def main():
proxy = RoundRobinProxy([8100, 8200])
app = web.Application()
app.router.add_route('*', '/{path:.*}', proxy.handle_request)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', 8000)
await site.start()
print("Proxy server started on http://localhost:8000")
# Keep the server running
await asyncio.Event().wait()
if __name__ == '__main__':
asyncio.run(main())
import json
import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
data = []
for name in ['disagg_prefill', 'chunked_prefill']:
for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f)
x['name'] = name
x['qps'] = qps
data.append(x)
df = pd.DataFrame.from_dict(data)
dis_df = df[df['name'] == 'disagg_prefill']
chu_df = df[df['name'] == 'chunked_prefill']
plt.style.use('bmh')
plt.rcParams['font.size'] = 20
for key in [
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
'median_itl_ms', 'p99_itl_ms'
]:
fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(dis_df['qps'],
dis_df[key],
label='disagg_prefill',
marker='o',
linewidth=4)
plt.plot(chu_df['qps'],
chu_df[key],
label='chunked_prefill',
marker='o',
linewidth=4)
ax.legend()
ax.set_xlabel('QPS')
ax.set_ylabel(key)
ax.set_ylim(bottom=0)
fig.savefig(f'results/{key}.png')
plt.close(fig)
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi
# a function that waits vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# You can also adjust --kv-ip and --kv-port for distributed inference.
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (port 8100), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance
# NOTE: the usage of this API is subject to change --- in the future we will
# introduce "vllm connect" to connect between prefill and decode instances
python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
sleep 1
# serve two example requests
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo ""
sleep 1
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
import os
import subprocess
import sys
import time
from subprocess import Popen
import pytest
import requests
import torch
# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 4:
pytest.skip("Skipping test: fewer than 4 GPUs available")
# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
shell=True).decode().strip()
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
# Start prefill instance
prefill_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--port",
"8100",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
'"kv_rank":0,"kv_parallel_size":2}',
]
prefill_env = os.environ.copy()
prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
prefill_proc = Popen(prefill_cmd, env=prefill_env)
# Start decode instance
decode_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--port",
"8200",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
'"kv_rank":1,"kv_parallel_size":2}',
]
decode_env = os.environ.copy()
decode_env["CUDA_VISIBLE_DEVICES"] = "1"
decode_proc = Popen(decode_cmd, env=decode_env)
# Wait for servers to be ready
assert wait_for_server(8100), "Prefill server did not start in time"
assert wait_for_server(8200), "Decode server did not start in time"
# Yield to the test function and handle teardown after tests
yield
# Cleanup: kill the processes
prefill_proc.terminate()
decode_proc.terminate()
# Additional cleanup if needed
prefill_proc.wait()
decode_proc.wait()
# Helper function to wait for server
def wait_for_server(port, timeout=240):
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"http://localhost:{port}/v1/completions")
if response.status_code in [200, 405]:
return True
except requests.ConnectionError:
time.sleep(1)
return False
# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
# Send to prefill
response = requests.post("http://localhost:8100/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 1,
"temperature": 0
})
assert response.status_code == 200
# Send to decode
response = requests.post("http://localhost:8200/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 10,
"temperature": 0
})
assert response.status_code == 200
import subprocess
import sys
import pytest
import torch
def run_python_script(script_name, timeout):
script_name = f'kv_transfer/{script_name}'
try:
# Start both processes asynchronously using Popen
process0 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK":
"0"}, # Set the RANK environment variable for process 0
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
process1 = subprocess.Popen(
[sys.executable, script_name],
env={"RANK":
"1"}, # Set the RANK environment variable for process 1
stdout=sys.stdout, # Pipe stdout to current stdout
stderr=sys.stderr, # Pipe stderr to current stderr
)
# Wait for both processes to complete, with a timeout
process0.wait(timeout=timeout)
process1.wait(timeout=timeout)
# Check the return status of both processes
if process0.returncode != 0:
pytest.fail(
f"Test {script_name} failed for RANK=0, {process0.returncode}")
if process1.returncode != 0:
pytest.fail(
f"Test {script_name} failed for RANK=1, {process1.returncode}")
except subprocess.TimeoutExpired:
# If either process times out, terminate both and fail the test
process0.terminate()
process1.terminate()
pytest.fail(f"Test {script_name} timed out")
except Exception as e:
pytest.fail(f"Test {script_name} failed with error: {str(e)}")
# Define the test cases using pytest's parametrize
@pytest.mark.parametrize(
"script_name,timeout",
[
("test_lookup_buffer.py",
60), # Second test case with a 60-second timeout
("test_send_recv.py", 120) # First test case with a 120-second timeout
])
def test_run_python_script(script_name, timeout):
# Check the number of GPUs
if torch.cuda.device_count() < 2:
pytest.skip(
f"Skipping test {script_name} because <2 GPUs are available")
# Run the test if there are at least 2 GPUs
run_python_script(script_name, timeout)
import os
import random
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
# TODO: the test depends on a lot of fields in the current implementation.
# We should have standard interface instead direct field access
def test_run(my_rank, buffer, device):
# buffer should be empty in the beginning
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print("My rank: %d, device: %s" % (my_rank, device))
# insert
tokens = torch.tensor([1, 2, 3]).to(device)
roi = (tokens > 0)
if my_rank == 0:
key = 2.0 * torch.ones([5, 6]).to(device)
value = 3.0 * torch.ones([5, 6]).to(device)
placeholder = torch.tensor([1]).to(device)
buffer.insert(tokens, roi, key, value, placeholder)
torch.distributed.barrier()
# drop_select
if my_rank == 1:
tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
assert torch.allclose(tokens, tok)
assert torch.allclose(roi, roi_)
assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
torch.distributed.barrier()
if my_rank == 0:
assert buffer.buffer_size == 0
assert len(buffer.buffer) == 0
print("Test run passed!")
def stress_test(my_rank, buf, device):
torch.distributed.barrier()
torch.manual_seed(100)
reqs = [
(
torch.rand(100).to(device), # tokens
torch.ones(100).bool().to(device), # roi
torch.rand(100).to(device), # key
torch.rand(100).to(device), # value
torch.rand(100).to(device), # hidden
) for i in tqdm(range(200))
]
random.seed(my_rank)
random.shuffle(reqs)
torch.distributed.barrier()
n = 0
# the buffer size can only store 100 reqs
# so the sender will occasionally block to wait for the receiver.
for req in tqdm(reqs):
if my_rank == 0:
buf.insert(*req)
else:
tok, roi, k, v, h = req
tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)
if tok_ is None:
assert roi_ is None
assert k_ is None
assert v_ is None
assert h_ is None
n += 1
else:
assert torch.allclose(tok, tok_)
assert torch.allclose(roi, roi_)
assert torch.allclose(k, k_)
assert torch.allclose(v, v_)
assert torch.allclose(h, h_)
print('Rank %d done' % my_rank)
torch.distributed.barrier()
if my_rank == 0:
x = torch.tensor([0])
torch.distributed.recv(x, 1)
# the # of None received is the kv that are not selected
assert x.item() == len(buf.buffer)
# and the size of the buffer should be 2000 * buffer len
print(buf.buffer_size)
assert buf.buffer_size == 1700 * len(buf.buffer)
else:
torch.distributed.send(torch.tensor([n]), 0)
print("Passed stress test!")
if __name__ == "__main__":
my_rank = int(os.environ['RANK'])
torch.distributed.init_process_group(
backend='gloo',
init_method='tcp://localhost:12398',
world_size=2,
rank=my_rank,
)
print("initialized! My rank is %d" % my_rank)
config = KVTransferConfig(
kv_connector='PyNcclConnector',
kv_buffer_device='cuda',
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
data_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cuda",
port_offset=0,
)
cpu_pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
device="cpu",
port_offset=1,
)
buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)
test_run(my_rank, buffer, data_pipe.device)
stress_test(my_rank, buffer, data_pipe.device)
buffer.close()
data_pipe.close()
cpu_pipe.close()
print('Done')
#!/bin/bash
RANK=0 python test_lookup_buffer.py &
RANK=1 python test_lookup_buffer.py &
\ No newline at end of file
import os
import time
from typing import List
import torch
from tqdm import tqdm
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
def test_run(my_rank, pipe):
# test run
x = torch.tensor([1]).to(pipe.device)
y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device)
if my_rank == 0:
pipe.send_tensor(x)
print("sent tensor x")
pipe.send_tensor(y)
print("sent tensor y")
x2 = pipe.recv_tensor()
print("received x2 = ", x2)
y2 = pipe.recv_tensor()
print("received y2 = ", x2)
else:
x2 = pipe.recv_tensor()
print("received x2 = ", x2)
y2 = pipe.recv_tensor()
print("received y2 = ", x2)
pipe.send_tensor(x)
print("sent tensor x")
pipe.send_tensor(y)
print("sent tensor y")
assert torch.allclose(x, x2)
assert torch.allclose(y, y2)
def stress_test(my_rank, pipe):
torch.distributed.barrier()
tensors: List[torch.Tensor] = []
torch.manual_seed(0)
for i in tqdm(range(500)):
mean = torch.rand(1).item() * 100
std = torch.rand(1).item() * 100
size = torch.randint(900, 1000, (2, ))
x = torch.normal(mean * 1.0, std * 1.0,
size=size.tolist()).to(pipe.device)
# 5% probability of sending a None
if torch.rand(1).item() < 0.05:
tensors.append(None)
tensors.append(None)
tensors.append(None)
else:
tensors.append(x)
tensors.append(x.mean().unsqueeze(0))
tensors.append(x.std().unsqueeze(0))
torch.distributed.barrier()
for i in tqdm(range(500)):
if my_rank == int((i % 10) > 3):
pipe.send_tensor(tensors[3 * i])
pipe.send_tensor(tensors[3 * i + 1])
pipe.send_tensor(tensors[3 * i + 2])
else:
x = pipe.recv_tensor()
mean = pipe.recv_tensor()
std = pipe.recv_tensor()
if x is None:
assert mean is None
assert std is None
else:
assert torch.allclose(x, tensors[3 * i])
assert x.mean() == mean[0]
assert x.std() == std[0]
torch.distributed.barrier()
def latency_test(my_rank, pipe, nelement, ntensor):
latencies = []
torch.distributed.barrier()
for i in tqdm(range(500)):
tensors = []
if my_rank == 0:
# create tensor
tensors = [
torch.rand(nelement).to(pipe.device) for _ in range(ntensor)
]
torch.distributed.barrier()
if my_rank == 0:
t = torch.tensor([time.time()],
dtype=torch.float64).to(pipe.device)
for tensor in tensors:
pipe.send_tensor(tensor)
pipe.send_tensor(t)
else:
for _ in range(ntensor):
pipe.recv_tensor()
t = pipe.recv_tensor()
latencies.append(time.time() - t.item())
torch.distributed.barrier()
print('Latency test passed.')
print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms')
if __name__ == "__main__":
my_rank = int(os.environ['RANK'])
torch.distributed.init_process_group(
backend='gloo',
init_method='tcp://localhost:12398',
world_size=2,
rank=my_rank,
)
config = KVTransferConfig(
kv_connector='PyNcclConnector',
kv_buffer_device='cuda',
kv_buffer_size=1e9,
kv_rank=my_rank,
kv_role="kv_both", # this arg doesn't matter in this test
kv_parallel_size=2,
kv_ip="127.0.0.1",
kv_port=12345,
)
pipe = PyNcclPipe(
local_rank=my_rank,
config=config,
)
test_run(my_rank, pipe)
stress_test(my_rank, pipe)
# Use this function if you want to test the latency of pipe impl.
# latency_test(my_rank, pipe, 1024 * 8 * 128, 80)
#!/bin/bash
RANK=0 python3 test_send_recv.py &
RANK=1 python3 test_send_recv.py &
\ No newline at end of file
...@@ -2052,6 +2052,88 @@ class ObservabilityConfig: ...@@ -2052,6 +2052,88 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}") f"installed. Original error:\n{otel_import_error_traceback}")
class KVTransferConfig(BaseModel):
"""Configuration for distributed KV cache transfer."""
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
# The buffer size for TorchDistributedConnector. Measured in number of
# bytes. Recommended value: 1e9 (about 1GB).
kv_buffer_size: float = 1e9
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'both'.
kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value:
# 0 for prefill instance, 1 for decode instance.
# Currently only 1P1D is supported.
kv_rank: Optional[int] = None
# The number of parallel instances for KV cache transfer. For
# PyNcclConnector, this should be 2.
kv_parallel_size: int = 1
# The KV connector ip, used to build distributed connection
kv_ip: str = "127.0.0.1"
# The KV connector port, used to build distributed connection
kv_port: int = 14579
@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
"""Parse the CLI value for the compilation config."""
return KVTransferConfig.model_validate_json(cli_value)
def model_post_init(self, __context: Any) -> None:
if all([
self.kv_connector is not None,
self.kv_connector != "PyNcclConnector"
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"`PyNcclConnector`.")
if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
]:
raise ValueError(
f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
if self.kv_connector is not None and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
@property
def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1.
return self.kv_connector is not None and self.kv_parallel_size > 1
@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_both"]
@property
def is_kv_consumer(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
class CompilationLevel: class CompilationLevel:
# constants for the levels of the compilation process # constants for the levels of the compilation process
NO_COMPILATION = 0 NO_COMPILATION = 0
...@@ -2317,6 +2399,8 @@ class VllmConfig: ...@@ -2317,6 +2399,8 @@ class VllmConfig:
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None, compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
@staticmethod @staticmethod
def _get_quantization_config( def _get_quantization_config(
......
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling.
## Abstractions
The KV cache transfer contains three layer of abstractions:
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
## Disaggregated prefilling
The example usage is in [this file](../../../examples/disaggregated_prefill.sh).
Here is the diagram of how we run disaggretgated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
connector when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.
This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.
Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.
Returns:
None
"""
raise NotImplementedError
@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
This method attempts to retrieve KV caches and hidden states for input
tokens. If all required KV caches and hidden states are received, it
will bypass model input, else it will fall back to normal vLLM model
forwarding.
Args:
model_executable (torch.nn.Module):
The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner.
kv_caches (List[torch.Tensor]):
List of KV caches for each layer.
Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
Optionally adjusted input metadata for re-execution when
`bypass_model_exec=False`.
"""
raise NotImplementedError
from typing import TYPE_CHECKING
from .base import KVConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
class KVConnectorFactory:
@staticmethod
def create_connector(rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment