Unverified Commit d1911020 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[P/D] NIXL Integration (#17751)


Signed-off-by: default avatarApostaC <yihua98@uchicago.edu>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarApostaC <yihua98@uchicago.edu>
Co-authored-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarBrent Salisbury <bsalisbu@redhat.com>
parent 05a4324f
...@@ -214,6 +214,7 @@ steps: ...@@ -214,6 +214,7 @@ steps:
- pytest -v -s v1/worker - pytest -v -s v1/worker
- pytest -v -s v1/structured_output - pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode - pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py - pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py - pytest -v -s v1/test_utils.py
......
...@@ -870,7 +870,7 @@ def test_kv_connector_basic(): ...@@ -870,7 +870,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = ( scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS) NUM_MATCHED_NEW_TOKENS, False)
###################################################### ######################################################
# FIRST SET OF REQUESTS - External Hit Only # FIRST SET OF REQUESTS - External Hit Only
...@@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate(): ...@@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = ( scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS) NUM_MATCHED_NEW_TOKENS, False)
# Create two requests. The second request will not be able to # Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks. # allocate slots because it will not have enough blocks.
...@@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption(): ...@@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = ( scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS) NUM_MATCHED_NEW_TOKENS, False)
# Create two requests. # Create two requests.
# Both can be scheduled at first, but the second request # Both can be scheduled at first, but the second request
......
#!/bin/bash
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Handle to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Arrays to store all hosts and ports
PREFILL_HOSTS=()
PREFILL_PORTS=()
DECODE_HOSTS=()
DECODE_PORTS=()
# Start prefill instances
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
# Calculate port number (base port + instance number)
PORT=$((8100 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
--port $PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
PREFILL_HOSTS+=("localhost")
PREFILL_PORTS+=($PORT)
done
# Start decode instances
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
# Calculate port number (base port + instance number)
PORT=$((8200 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
--port $PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
DECODE_HOSTS+=("localhost")
DECODE_PORTS+=($PORT)
done
# Wait for all instances to start
for PORT in "${PREFILL_PORTS[@]}"; do
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PORT
done
for PORT in "${DECODE_PORTS[@]}"; do
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $PORT
done
# Build the command for the proxy server with all the hosts and ports
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
# Add all prefill hosts and ports
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"
# Add all decode hosts and ports
PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"
#!/bin/bash
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Handle to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Start prefill instance
PREFILL_PORT=8001
BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
--port $PREFILL_PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Start decode instance
DECODE_PORT=8002
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
--port $DECODE_PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Wait for all instances to start
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $DECODE_PORT
# Build the command for the proxy server with all the hosts and ports
PROXY_PORT=8192
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"
# SPDX-License-Identifier: Apache-2.0
import os
import lm_eval
import openai
BASE_URL = "http://localhost:8192/v1"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
}
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
# Get model name from environment variable
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
def run_simple_prompt():
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
completion = client.completions.create(model=MODEL_NAME,
prompt=SIMPLE_PROMPT)
print("-" * 50)
print(f"Completion results for {MODEL_NAME}:")
print(completion)
print("-" * 50)
def test_accuracy():
"""Run the end to end accuracy test."""
run_simple_prompt()
model_args = (f"model={MODEL_NAME},"
f"base_url={BASE_URL}/completions,"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
expected_value = EXPECTED_VALUES.get(MODEL_NAME)
if expected_value is None:
print(f"Warning: No expected value found for {MODEL_NAME}. "
"Skipping accuracy check.")
print(f"Measured value: {measured_value}")
return
assert (measured_value - RTOL < expected_value
and measured_value + RTOL > expected_value
), f"Expected: {expected_value} | Measured: {measured_value}"
# SPDX-License-Identifier: Apache-2.0
import os
import openai
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
raise ValueError(
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
SHORT_PROMPT = "Red Hat is "
def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1",
)
# Get the list of models
models = decode_client.models.list()
MODEL = models.data[0].id
# (1) Check that we can handle a very short prompt,
# less than the length of the block size.
completion = proxy_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"SMALL PROMPT: {proxy_response=}")
assert proxy_response == prefill_response
# (2) Check that we can handle a full prefix cache
# hit on the D worker but not on the P worker.
# (2a): prime the D worker.
completion = decode_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
decode_response = completion.choices[0].text
# (2b): send via the P/D setup
completion = proxy_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
print(f"FULL CACHE HIT: {proxy_response=}")
assert proxy_response == decode_response
# (3) Check that we can handle a partial prefix cache
# hit on the D worker.
completion = proxy_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"PARTIAL CACHE HIT: {proxy_response=}")
assert proxy_response == prefill_response
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f'http://{host}:{port}/v1'
app.state.prefill_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
'host':
host,
'port':
port,
'id':
i
})
# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f'http://{host}:{port}/v1'
app.state.decode_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
'host':
host,
'port':
port,
'id':
i
})
# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(
range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(
range(len(app.state.decode_clients)))
print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients.")
yield
# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info['client'].aclose()
for client_info in app.state.decode_clients:
await client_info['client'].aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
# For prefiller instances
parser.add_argument("--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
"--prefiller-port",
type=int,
nargs="+",
default=[8100])
# For decoder instances
parser.add_argument("--decoder-hosts",
"--decoder-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports",
"--decoder-port",
type=int,
nargs="+",
default=[8200])
args = parser.parse_args()
# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
# Create tuples of (host, port) for each service type
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.
Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'
Returns:
The next client to use
"""
if service_type == 'prefill':
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == 'decode':
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")
async def send_request_to_service(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
response = await client_info['client'].post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async with client_info['client'].stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service
response = await send_request_to_service(prefill_client_info,
"/completions", req_data,
request_id)
# Extract the needed fields
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
"/completions",
req_data,
request_id=request_id):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients)
}
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata)
from .utils import create_request, create_scheduler, create_vllm_config
def test_basic_inferface():
"""Unit test for basic NixlConnector interface functionality."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
scheduler.add_request(request)
# Remote Prefill, triggers NixlConnectorMetdata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
assert len(kv_connector_metadata.requests) == 1
assert request_id in kv_connector_metadata.requests
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.
single_type_manager.req_to_blocks[request_id]):
assert block_id == block.block_id
def test_prompt_less_than_block_size():
"""
Test that we can handle case where prompt is < block.
In this case, the P worker will send empty remote_block_ids.
The D worker should not schedule an async read in this case,
since there is nothing to pull.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Half of a block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
# Request will have 0 remote blocks.
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=0)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
# This request should not have to read async.
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
assert len(kv_connector_metadata.requests) == 0
# This request should be scheduled regularly.
assert len(scheduler_output.scheduled_new_reqs) == 1
# SPDX-License-Identifier: Apache-2.0
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
# Ensure the request is finished after 1 tokens.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and in Persistent Batch ...
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block.ref_cnt == 1
# STEP (2): Send Finished to PB.
# (2a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_id]
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_short_prompt_lifecycle():
"""Test lifecycle of a Remote Decode request with short prompt."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block.
NUM_TOKENS = vllm_config.cache_config.block_size // 2
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
# Since tokens < block_size, there will be no kv xfer.
# So this should be cleaned up immediately.
_ = scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
# We need one more call to schedule() to clear data for persistent batch.
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still works with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_remote.request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
# SPDX-License-Identifier: Apache-2.0
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert (request.num_computed_tokens == 0)
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(engine_core_outputs.outputs) == 0
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
def test_interleaved_lifecycle():
"""Test Remote Prefills Work Well With Other Requests."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_local_a = create_request(
request_id=2,
num_tokens=NUM_TOKENS,
)
request_local_b = create_request(
request_id=3,
num_tokens=NUM_TOKENS,
)
# STEP 1: Regular request is running.
scheduler.add_request(request_local_a)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
model_runner_output = create_model_runner_output([request_local_a])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 2: Add a local and remote request.
scheduler.add_request(request_local_b)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 1
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 3: continue running, KVs not arrived yet.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b],
finished_recving=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 5: RECVed KVs are sent to ModelRunner.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 6: Hit EOS and free.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote],
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
request_local = create_request(
request_id=2,
num_tokens=NUM_TOKENS,
do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
# Schedule the local prefill request. This should
# cause blocks to be cached, but separately from
scheduler.add_request(request_local)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501
request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks = 0
for block in local_blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks > 0
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# # STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# # Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 64,
block_size: int = 16,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
)
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
) -> Request:
"""Make dummy request for testing."""
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = NixlKVTransferParams(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234)
else:
kv_transfer_params = None
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=finished_sending,
finished_recving=finished_recving,
)
...@@ -8,6 +8,7 @@ import inspect ...@@ -8,6 +8,7 @@ import inspect
import json import json
import re import re
import textwrap import textwrap
import uuid
import warnings import warnings
from collections import Counter from collections import Counter
from contextlib import contextmanager from contextlib import contextmanager
...@@ -3438,6 +3439,9 @@ class KVTransferConfig: ...@@ -3438,6 +3439,9 @@ class KVTransferConfig:
"""The KV connector for vLLM to transmit KV caches between vLLM instances. """The KV connector for vLLM to transmit KV caches between vLLM instances.
""" """
engine_id: str = str(uuid.uuid4())
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda" kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache. """The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'.""" Currently only support 'cuda'."""
...@@ -3448,7 +3452,7 @@ class KVTransferConfig: ...@@ -3448,7 +3452,7 @@ class KVTransferConfig:
kv_role: Optional[KVRole] = None kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices """Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'both'.""" are 'kv_producer', 'kv_consumer', and 'kv_both'."""
kv_rank: Optional[int] = None kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value: """The rank of this vLLM instance in the KV cache transfer. Typical value:
......
...@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector( ...@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector(
"LMCacheConnectorV1", "LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1") "LMCacheConnectorV1")
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole) KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
__all__ = [ __all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
"KVConnectorRole",
"KVConnectorBase_V1",
]
...@@ -23,7 +23,7 @@ The class provides the following primitives: ...@@ -23,7 +23,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Optional
import torch import torch
...@@ -34,6 +34,7 @@ if TYPE_CHECKING: ...@@ -34,6 +34,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum): ...@@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVTransferParams:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["KVTransferParams"]:
return None
@dataclass @dataclass
class KVConnectorMetadata: class KVConnectorMetadata:
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
"""
pass pass
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning( logger.warning(
...@@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC): ...@@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC):
def role(self) -> KVConnectorRole: def role(self) -> KVConnectorRole:
return self._role return self._role
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata( def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None: self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler. """Set the connector metadata from the scheduler.
...@@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC): ...@@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC):
""" """
return self._connector_metadata return self._connector_metadata
# ============================== def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# Worker-side methods """
# ============================== Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
"""
return
@abstractmethod @abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext",
...@@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC): ...@@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC):
""" """
pass pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous (recving, sending).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod @abstractmethod
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> int: ) -> tuple[int, bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
...@@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC): ...@@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC):
computed tokens for this request computed tokens for this request
Returns: Returns:
the number of tokens that can be loaded from the * the number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
* true if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
""" """
pass pass
@abstractmethod @abstractmethod
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int): num_external_tokens: int):
""" """
Update KVConnector state after block allocation. Update KVConnector state after block allocation.
...@@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC): ...@@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC):
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
pass pass
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None
...@@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> int: ) -> tuple[int, bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
...@@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
""" """
return self._lmcache_engine.get_num_new_matched_tokens( return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens), False
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int): num_external_tokens: int):
""" """
Update KVConnector state after block allocation. Update KVConnector state after block allocation.
......
This diff is collapsed.
...@@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata # Get the metadata
metadata: KVConnectorMetadata = \ metadata: KVConnectorMetadata = self._get_connector_metadata()
self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata) assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None: if metadata is None:
...@@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> int: ) -> tuple[int, bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
...@@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1):
the number of tokens that can be loaded from the the number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
""" """
# NOTE: in this debug implementation, we assume that the prompt is # NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token # cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name # Therefore, we use prompt_token_ids[:-1] to determine the folder name
...@@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# with the block granularity. And it expects the returned blocks and # with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity. # num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request): if not self._found_match_for_request(request):
return 0 return 0, False
logger.info("External Cache Hit!") logger.info("External Cache Hit!")
...@@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
num_tokens_to_check = align_to_block_size( num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size) len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int): num_external_tokens: int):
""" """
Update KVConnector state after block allocation. Update KVConnector state after block allocation.
......
...@@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
"access by 3rd parties, and long enough to be " "access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding " "unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0.")) "to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding, guided_decoding=guided_decoding,
logit_bias=self.logit_bias) logit_bias=self.logit_bias,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
def _get_guided_json_from_tool( def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]: self) -> Optional[Union[str, dict, BaseModel]]:
...@@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens " " as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified.")) "that are not JSON-encodable can be identified."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-completion-extra-params # doc: end-completion-extra-params
# Default sampling parameters for completion requests # Default sampling parameters for completion requests
...@@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel):
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding, guided_decoding=guided_decoding,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids) allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel): ...@@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel):
model: str model: str
choices: list[CompletionResponseChoice] choices: list[CompletionResponseChoice]
usage: UsageInfo usage: UsageInfo
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class CompletionResponseStreamChoice(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel):
...@@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel): ...@@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel):
choices: list[ChatCompletionResponseChoice] choices: list[ChatCompletionResponseChoice]
usage: UsageInfo usage: UsageInfo
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class DeltaMessage(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment