Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
......@@ -530,7 +530,7 @@ spec:
name: accelerators-thanos-querier-datasource
# Multiply by 100 so we can read it as a percentage without setting a unit (avoids CUE unit conflicts)
query: >
100 * avg(vllm:gpu_cache_usage_perc)
100 * avg(vllm:kv_cache_usage_perc)
"18":
kind: Panel
......
......@@ -98,7 +98,7 @@ spec:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0)
query: avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_running_ts:
......@@ -168,7 +168,7 @@ spec:
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
# multiply by 100 to present percentage; omit format.unit to avoid schema conflicts
query: (avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
query: (avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
core_kv_usage_pct_ts:
......@@ -187,7 +187,7 @@ spec:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: (avg by (service) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
query: (avg by (service) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
# --- Per-Pod breakdowns (works on Simulator & Real) ---
......@@ -246,7 +246,7 @@ spec:
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
# if your exporter labels kv metric with pod (the sim does), this works; otherwise it will just return empty
query: (avg by (pod) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
query: (avg by (pod) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
# --- Real vLLM only (zeros on simulator) ---
......
# Disaggregated Encoder
These example scripts that demonstrate the disaggregated encoder (EPD) features of vLLM.
For a detailed explanation of the EPD features, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md).
## Files
- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration.
- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image.
- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image.
### Custom Configuration
```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh
# Use specific ports
ENDPOINT_PORT=10001 bash disagg_1e1p1d_example.sh
# Use specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
# Use specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
```
## Encoder Instances
Encoder engines should be launched with the following flags:
- `--enforce-eager` **(required)** – The current EPD implementation is only compatible with encoder instances running in this mode.
- `--no-enable-prefix-caching` **(required)** – Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features.
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
## Local media inputs
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
```bash
--allowed-local-media-path $MEDIA_PATH
```
The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url": "file://'"$MEDIA_PATH_FILENAME"'}``` as multimodal inputs. Each URI is passed unchanged from the `disagg_encoder_proxy` to the encoder instance so that the encoder can load the media locally.
## EC connector and KV transfer
The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:
```bash
# Add to encoder instance:
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
# Add to prefill/prefill+decode instance:
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
```
`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache.
If you enable prefill instance (`--prefill-servers-urls` not disabled), you will need --kv-transfer-config to facilitate the PD disaggregation. Currently, we use the `NixlConnector` for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example codes on PD disaggregation with Nixl.
```bash
# Add to prefill instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}'
# Add to decode instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}'
```
## Proxy Instance Flags (`disagg_epd_proxy.py`)
| Flag | Description |
|------|-------------|
| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. |
| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). |
| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. |
| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). |
Example usage:
For E + PD setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8002" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://pd1:8003,http://pd2:8004"
```
For E + P + D setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8001" \
--prefill-servers-urls "http://p1:8003,http://p2:8004" \
--decode-servers-urls "http://d1:8005,http://d2:8006"
```
#!/bin/bash
set -euo pipefail
declare -a PIDS=()
###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p $LOG_PATH
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_PORT="${PREFILL_PORT:-19535}"
DECODE_PORT="${DECODE_PORT:-19536}"
PROXY_PORT="${PROXY_PORT:-10001}"
GPU_E="${GPU_E:-2}"
GPU_P="${GPU_P:-2}"
GPU_D="${GPU_D:-3}"
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
export UCX_TLS=all
export UCX_NET_DEVICES=all
###############################################################################
# Helpers
###############################################################################
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
P_LOG=$LOG_PATH/p_${START_TIME}.log
D_LOG=$LOG_PATH/d_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Stopping everything…"
trap - INT TERM USR1 # prevent re-entrancy
# Kill all tracked PIDs
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid" 2>/dev/null
fi
done
# Wait a moment for graceful shutdown
sleep 2
# Force kill any remaining processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -9 "$pid" 2>/dev/null
fi
done
# Kill the entire process group as backup
kill -- -$$ 2>/dev/null
echo "All processes stopped."
exit 0
}
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
# clear previous cache
echo "remove previous ec cache folder"
rm -rf $EC_SHARED_STORAGE_PATH
echo "make ec cache folder"
mkdir -p $EC_SHARED_STORAGE_PATH
###############################################################################
# Encoder worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--gpu-memory-utilization 0.01 \
--port "$ENCODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${ENC_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Prefill worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_P" \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$PREFILL_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}' \
>"${P_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Decode worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_D" \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$DECODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}' \
>"${D_LOG}" 2>&1 &
PIDS+=($!)
# Wait for workers
wait_for_server $ENCODE_PORT
wait_for_server $PREFILL_PORT
wait_for_server $DECODE_PORT
###############################################################################
# Proxy
###############################################################################
python disagg_epd_proxy.py \
--host "0.0.0.0" \
--port "$PROXY_PORT" \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "http://localhost:$PREFILL_PORT" \
--decode-servers-urls "http://localhost:$DECODE_PORT" \
>"${PROXY_LOG}" 2>&1 &
PIDS+=($!)
wait_for_server $PROXY_PORT
echo "All services are up!"
###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
--model $MODEL \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat \
--seed 0 \
--num-prompts $NUM_PROMPTS \
--port $PROXY_PORT
PIDS+=($!)
###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'${MODEL}'",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
{"type": "text", "text": "What is in this image?"}
]}
]
}'
# cleanup
echo "cleanup..."
cleanup
\ No newline at end of file
#!/bin/bash
set -euo pipefail
declare -a PIDS=()
###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p $LOG_PATH
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
PROXY_PORT="${PROXY_PORT:-10001}"
GPU_E="${GPU_E:-0}"
GPU_PD="${GPU_PD:-1}"
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
###############################################################################
# Helpers
###############################################################################
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Stopping everything…"
trap - INT TERM USR1 # prevent re-entrancy
# Kill all tracked PIDs
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid" 2>/dev/null
fi
done
# Wait a moment for graceful shutdown
sleep 2
# Force kill any remaining processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -9 "$pid" 2>/dev/null
fi
done
# Kill the entire process group as backup
kill -- -$$ 2>/dev/null
echo "All processes stopped."
exit 0
}
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
# clear previous cache
echo "remove previous ec cache folder"
rm -rf $EC_SHARED_STORAGE_PATH
echo "make ec cache folder"
mkdir -p $EC_SHARED_STORAGE_PATH
###############################################################################
# Encoder worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--gpu-memory-utilization 0.01 \
--port "$ENCODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${ENC_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Prefill+Decode worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$PREFILL_DECODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${PD_LOG}" 2>&1 &
PIDS+=($!)
# Wait for workers
wait_for_server $ENCODE_PORT
wait_for_server $PREFILL_DECODE_PORT
###############################################################################
# Proxy
###############################################################################
python disagg_epd_proxy.py \
--host "0.0.0.0" \
--port "$PROXY_PORT" \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
>"${PROXY_LOG}" 2>&1 &
PIDS+=($!)
wait_for_server $PROXY_PORT
echo "All services are up!"
###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
--model $MODEL \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat \
--seed 0 \
--num-prompts $NUM_PROMPTS \
--port $PROXY_PORT
PIDS+=($!)
###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'${MODEL}'",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
{"type": "text", "text": "What is in this image?"}
]}
]
}'
# cleanup
echo "cleanup..."
cleanup
\ No newline at end of file
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
disagg_encoder_proxy.py
Proxy that routes OpenAI-compatible “/v1/chat/completions” requests to two
clusters:
• encode (multimodal feature extraction)
• decode (language-model inference)
For MM input we:
1. Extract *every* image/audio item.
2. Fire N concurrent requests to the encoder cluster
(one request per item, with **all text removed**).
3. Wait for all of them to succeed.
4. Forward the *original* request to a decode server.
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import random
import uuid
from collections.abc import AsyncIterator
import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
###############################################################################
# FastAPI app & global state
###############################################################################
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger("proxy")
app = FastAPI()
encode_session: aiohttp.ClientSession | None = None
prefill_session: aiohttp.ClientSession | None = None
decode_session: aiohttp.ClientSession | None = None
###############################################################################
# Utils
###############################################################################
MM_TYPES = {"image_url", "audio_url", "input_audio"}
def extract_mm_items(request_data: dict) -> list[dict]:
"""
Return *all* image/audio items that appear anywhere in `messages`.
Each returned dict looks like:
{ "type": "image_url", "image_url": {...} }
"""
items: list[dict] = []
for msg in request_data.get("messages", []):
content = msg.get("content")
if not isinstance(content, list):
continue
for item in content:
if item.get("type") in MM_TYPES:
items.append(item)
return items
async def fanout_encoder_primer(
orig_request: dict,
e_urls: list[str],
req_id: str,
) -> None:
"""
1. Build one request *per MM item* with all text removed.
2. Send them concurrently to the encode cluster.
3. Raise if any of them fails.
"""
logger.info("[%s] Processing multimodal items...", req_id)
mm_items = extract_mm_items(orig_request)
if not mm_items:
logger.info("[%s] No multimodal items, skipping encoder", req_id)
return # nothing to do
logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
tasks = []
# Round-robin over encode servers to distribute load a bit
url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
# Derive a *child* request id: <parent>:<index>:<random-short>
child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
headers = {"x-request-id": child_req_id}
encoder_req = {
# You *may* need to keep additional fields
"model": orig_request.get("model"),
"messages": [
{"role": "user", "content": [item]},
],
# Only need 1 token so the server actually runs the encoder path
"max_tokens": 1,
"stream": False,
}
tasks.append(
encode_session.post(
f"{target_url}/v1/chat/completions",
json=encoder_req,
headers=headers,
)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Fail fast if any sub-request failed
for idx, r in enumerate(results):
if isinstance(r, Exception):
logger.error(
"[%s] Encoder request #%d raised exception: %s",
req_id,
idx,
r,
exc_info=r,
)
raise HTTPException(
status_code=502, detail=f"Encoder request failed: {str(r)}"
)
if r.status != 200:
try:
detail = await r.text()
except Exception:
detail = "<unable to read body>"
logger.error(
"[%s] Encoder request #%d returned status %s: %s",
req_id,
idx,
r.status,
detail,
)
raise HTTPException(
status_code=r.status,
detail=f"Encoder request failed: {detail}",
)
logger.info(
"[%s] All %d encoder requests completed successfully", req_id, len(mm_items)
)
async def maybe_prefill(
req_data: dict,
p_url: str,
req_id: str,
) -> dict:
"""
- Do prefill-only task if p_url exist;
- Return modified request data with kv transfer params (for nixl connector)
- Else, skip and return the original request data for decode
"""
if p_url:
logger.info("[%s] Processing through prefill: %s", req_id, p_url)
prefill_response = await process_prefill_stage(req_data, p_url, req_id)
# for nixl connector to facilitate kv transfer...
prefill_response_json = await prefill_response.json()
kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
return req_data
else:
return req_data
async def process_prefill_stage(
req_data: dict,
p_url: str,
req_id: str,
) -> dict:
"""Process request through Prefill stage and return kv_transfer_params"""
logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
prefill_request = req_data.copy()
prefill_request["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
prefill_request["stream"] = False
prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
if "stream_options" in prefill_request:
del prefill_request["stream_options"]
headers = {"x-request-id": req_id}
try:
prefill_response = await prefill_session.post(
f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
)
prefill_response.raise_for_status()
if prefill_response.status != 200:
error_text = await prefill_response.text()
logger.error(
"[%s] Prefill request failed with status %d: %s",
req_id,
prefill_response.status,
error_text,
)
raise HTTPException(
status_code=prefill_response.status,
detail={"error": "Prefill request failed", "message": error_text},
)
logger.info("[%s] Prefill request completed successfully", req_id)
return prefill_response
except Exception as e:
logger.error("Prefill processing failed: %s", str(e))
raise HTTPException(
status_code=500,
detail={"error": "Prefill processing error", "message": str(e)},
) from e
###############################################################################
# Middleware for request/response logging
###############################################################################
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Middleware to log all incoming requests and responses"""
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
# Log incoming request
logger.info(
">>> [%s] %s %s from %s",
req_id,
request.method,
request.url.path,
request.client.host if request.client else "unknown",
)
try:
# Process request
response = await call_next(request)
# Log response
logger.info(
"<<< [%s] %s %s completed with status %d",
req_id,
request.method,
request.url.path,
response.status_code,
)
return response
except Exception as e:
# Log errors
logger.exception(
"!!! [%s] %s %s failed with error: %s",
req_id,
request.method,
request.url.path,
str(e),
)
raise
###############################################################################
# FastAPI lifecycle
###############################################################################
@app.on_event("startup")
async def on_startup() -> None:
global encode_session, prefill_session, decode_session
timeout = aiohttp.ClientTimeout(total=100_000)
connector = aiohttp.TCPConnector(limit=0, force_close=False)
encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
if app.state.p_urls:
# only setup if prefill instance(s) exist
prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
@app.on_event("shutdown")
async def on_shutdown() -> None:
global encode_session, prefill_session, decode_session
if encode_session:
await encode_session.close()
if prefill_session:
await prefill_session.close()
if decode_session:
await decode_session.close()
###############################################################################
# Core forwarding
###############################################################################
async def forward_non_stream(
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> dict:
try:
# Step 1: Process through Encoder instance (if has MM input)
await fanout_encoder_primer(req_data, e_urls, req_id)
# Step 2: Process through Prefill instance
req_data = await maybe_prefill(req_data, p_url, req_id)
# Step 3: Process through Decode instance
logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Non-streaming response
async with decode_session.post(
f"{d_url}/v1/chat/completions", json=req_data, headers=headers
) as resp:
resp.raise_for_status()
return await resp.json()
except HTTPException:
raise
except Exception as e:
logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
async def forward_stream(
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> AsyncIterator[str]:
try:
# Step 1: Process through Encoder instance (if has MM input)
await fanout_encoder_primer(req_data, e_urls, req_id)
# Step 2: Process through Prefill instance
req_data = await maybe_prefill(req_data, p_url, req_id)
# Step 3: Process through Decode instance
logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Streaming response
async with decode_session.post(
f"{d_url}/v1/chat/completions",
json=req_data,
headers=headers,
) as resp:
resp.raise_for_status()
async for chunk in resp.content.iter_chunked(1024):
if chunk:
yield chunk.decode("utf-8", errors="ignore")
logger.info("[%s] Streaming completed", req_id)
except HTTPException:
logger.exception("[%s] HTTPException in forward_stream", req_id)
raise
except Exception as e:
logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
raise HTTPException(
status_code=500, detail=f"Proxy streaming error: {str(e)}"
) from e
###############################################################################
# Public routes
###############################################################################
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
try:
req_data = await request.json()
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
e_urls = app.state.e_urls # we want the full list for fan-out
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
is_streaming = req_data.get("stream", False)
if is_streaming:
return StreamingResponse(
forward_stream(req_data, req_id, e_urls, p_url, d_url),
media_type="text/event-stream",
)
result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url)
return JSONResponse(content=result)
except HTTPException:
raise
except Exception as e:
logger.exception("Error in chat_completions endpoint: %s", str(e))
raise HTTPException(
status_code=500, detail=f"Request processing error: {str(e)}"
) from e
@app.get("/v1/models")
async def list_models():
async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
resp.raise_for_status()
return await resp.json()
@app.get("/health")
async def health_check():
async def healthy(urls):
if not urls:
return "empty"
for u in urls:
try:
async with encode_session.get(f"{u}/health") as resp:
resp.raise_for_status()
except Exception:
return "unhealthy"
return "healthy"
e_status, p_status, d_status = await asyncio.gather(
healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls)
)
overall_healthy = all(
status != "unhealthy" for status in (e_status, p_status, d_status)
)
status_code = 200 if overall_healthy else 503
return JSONResponse(
{
"proxy": "healthy",
"encode_cluster": e_status,
"prefill_cluster": p_status,
"decode_cluster": d_status,
},
status_code=status_code,
)
###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################
async def _post_if_available(
session: aiohttp.ClientSession,
url: str,
payload: dict,
headers: dict,
) -> dict | None:
"""
POST `payload` to `url`.
Returns
-------
• The decoded JSON body on success (2xx)
• None if the endpoint does not exist (404)
• Raises for anything else.
"""
try:
resp = await session.post(url, json=payload, headers=headers)
if resp.status == 404: # profiling disabled on that server
logger.warning("Profiling endpoint missing on %s", url)
return None
resp.raise_for_status()
return await resp.json(content_type=None)
except aiohttp.ClientResponseError as exc:
# Pass 404 through the branch above, re-raise everything else
if exc.status == 404:
logger.warning("Profiling endpoint missing on %s", url)
return None
raise
except Exception:
# Network errors etc.: propagate
raise
async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
"""
Fire & forget to both clusters, tolerate 404.
"""
headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}
encode_task = _post_if_available(
encode_session, f"{e_url}/{cmd}_profile", payload, headers
)
prefill_task = (
_post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
if p_url is not None
else asyncio.sleep(0)
)
decode_task = _post_if_available(
decode_session, f"{d_url}/{cmd}_profile", payload, headers
)
encode_res, prefill_res, decode_res = await asyncio.gather(
encode_task, prefill_task, decode_task
)
# If *all* clusters said “I don’t have that route”, surface an error
if encode_res is prefill_res is decode_res is None:
raise HTTPException(
status_code=503,
detail="Profiling endpoints are disabled on all clusters",
)
return {
"encode": encode_res, # may be None
"prefill": prefill_res, # may be None
"decode": decode_res, # may be None
}
@app.post("/start_profile")
async def start_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("start", body, e_url, p_url, d_url)
@app.post("/stop_profile")
async def stop_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("stop", body, e_url, p_url, d_url)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--encode-servers-urls",
required=True,
help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
)
parser.add_argument(
"--prefill-servers-urls",
required=True,
help=(
'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") ',
'to enable E->P->D, set "disable" or "none" to enable E->PD',
),
)
parser.add_argument(
"--decode-servers-urls",
required=True,
help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
)
args = parser.parse_args()
app.state.e_urls = [
u.strip() for u in args.encode_servers_urls.split(",") if u.strip()
]
app.state.d_urls = [
u.strip() for u in args.decode_servers_urls.split(",") if u.strip()
]
# handle prefill instances
if args.prefill_servers_urls.lower() in ("disable", "none", ""):
app.state.p_urls = []
logger.info(
"Disaggregated prefill phase explicitly disabled by user. Running E + PD..."
)
else:
app.state.p_urls = [
u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()
]
logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")
logger.info("Proxy listening on %s:%s", args.host, args.port)
logger.info("Encode servers: %s", app.state.e_urls)
logger.info("Prefill instances %s", app.state.p_urls)
logger.info("Decode servers: %s", app.state.d_urls)
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
loop="uvloop",
access_log=True,
)
......@@ -23,7 +23,7 @@ import logging
import os
import sys
from abc import ABC, abstractmethod
from typing import Callable, Optional
from collections.abc import Callable
import aiohttp
import requests
......@@ -49,12 +49,9 @@ class Proxy:
decode_instances: list[str],
model: str,
scheduling_policy: SchedulingPolicy,
custom_create_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
custom_create_chat_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
custom_create_completion: Callable[[Request], StreamingResponse] | None = None,
custom_create_chat_completion: Callable[[Request], StreamingResponse]
| None = None,
):
self.prefill_instances = prefill_instances
self.decode_instances = decode_instances
......@@ -203,9 +200,9 @@ class Proxy:
async with session.post(
url=url, json=data, headers=headers
) as response:
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
if 200 <= response.status < 300 or 400 <= response.status < 500:
if use_chunked:
async for chunk_bytes in response.content.iter_chunked( # noqa: E501
async for chunk_bytes in response.content.iter_chunked(
1024
):
yield chunk_bytes
......@@ -348,9 +345,9 @@ class ProxyServer:
def __init__(
self,
args: argparse.Namespace,
scheduling_policy: Optional[SchedulingPolicy] = None,
create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
scheduling_policy: SchedulingPolicy | None = None,
create_completion: Callable[[Request], StreamingResponse] | None = None,
create_chat_completion: Callable[[Request], StreamingResponse] | None = None,
):
self.validate_parsed_serve_args(args)
self.port = args.port
......
......@@ -166,7 +166,7 @@ main() {
local kv_port=$((21001 + i))
echo " Prefill server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port"
CUDA_VISIBLE_DEVICES=$gpu_id VLLM_USE_V1=1 vllm serve $MODEL \
CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \
--enforce-eager \
--host 0.0.0.0 \
--port $port \
......@@ -194,7 +194,7 @@ main() {
local kv_port=$((22001 + i))
echo " Decode server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port"
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \
CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \
--enforce-eager \
--host 0.0.0.0 \
--port $port \
......
......@@ -55,7 +55,6 @@ done
echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"
export RAY_DEDUP_LOGS=0
export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from typing import Any
import msgspec
import zmq
from msgspec.msgpack import Decoder
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
#
......@@ -24,17 +24,17 @@ class KVCacheEvent(
class BlockStored(KVCacheEvent):
block_hashes: list[BlockHash]
parent_block_hash: Optional[BlockHash]
block_hashes: list[ExternalBlockHash]
parent_block_hash: ExternalBlockHash | None
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
lora_id: int | None
medium: str | None
class BlockRemoved(KVCacheEvent):
block_hashes: list[BlockHash]
medium: Optional[str]
block_hashes: list[ExternalBlockHash]
medium: str | None
class AllBlocksCleared(KVCacheEvent):
......@@ -42,7 +42,7 @@ class AllBlocksCleared(KVCacheEvent):
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
def process_event(event_batch):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from typing import Optional
import threading
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger
"""
To run this example, run the following commands simultaneously with
......@@ -22,37 +23,64 @@ send a request to the instance with DP rank 1.
"""
def _do_background_logging(engine, interval, stop_event):
try:
while not stop_event.is_set():
asyncio.run(engine.do_log_stats())
stop_event.wait(interval)
except Exception as e:
print(f"vLLM background logging shutdown: {e}")
pass
async def main():
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto",
max_model_len=2048,
data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300,
data_parallel_size_local=1,
enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
)
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
# Example: Using aggregated logger
stat_loggers=[AggregatedLoggingStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=100,
)
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: RequestOutput | None = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
stop_logging_event.set()
logging_thread.join()
if __name__ == "__main__":
......
......@@ -26,7 +26,7 @@ import requests
from openai import OpenAI
from utils import get_first_model
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
......@@ -38,11 +38,13 @@ client = OpenAI(
base_url=openai_api_base,
)
headers = {"User-Agent": "vLLM Example Client"}
def encode_base64_content_from_url(content_url: str) -> str:
"""Encode a content retrieved from a remote url to base64 format."""
with requests.get(content_url) as response:
with requests.get(content_url, headers=headers) as response:
response.raise_for_status()
result = base64.b64encode(response.content).decode("utf-8")
......@@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str:
# Text-only inference
def run_text_only(model: str) -> None:
def run_text_only(model: str, max_completion_tokens: int) -> None:
chat_completion = client.chat.completions.create(
messages=[{"role": "user", "content": "What's the capital of France?"}],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion.choices[0].message.content
print("Chat completion output:", result)
print("Chat completion output:\n", result)
# Single-image input inference
def run_single_image(model: str) -> None:
def run_single_image(model: str, max_completion_tokens: int) -> None:
## Use image url in the payload
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_completion_from_url = client.chat.completions.create(
......@@ -79,11 +81,11 @@ def run_single_image(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from image url:", result)
print("Chat completion output from image url:\n", result)
## Use base64 encoded image in the payload
image_base64 = encode_base64_content_from_url(image_url)
......@@ -101,7 +103,7 @@ def run_single_image(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_base64.choices[0].message.content
......@@ -109,9 +111,9 @@ def run_single_image(model: str) -> None:
# Multi-image input inference
def run_multi_image(model: str) -> None:
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
def run_multi_image(model: str, max_completion_tokens: int) -> None:
image_url_duck = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg"
image_url_lion = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg"
chat_completion_from_url = client.chat.completions.create(
messages=[
{
......@@ -130,15 +132,15 @@ def run_multi_image(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_url.choices[0].message.content
print("Chat completion output:", result)
print("Chat completion output:\n", result)
# Video input inference
def run_video(model: str) -> None:
def run_video(model: str, max_completion_tokens: int) -> None:
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"
video_base64 = encode_base64_content_from_url(video_url)
......@@ -157,11 +159,11 @@ def run_video(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from image url:", result)
print("Chat completion output from video url:\n", result)
## Use base64 encoded video in the payload
chat_completion_from_base64 = client.chat.completions.create(
......@@ -178,15 +180,15 @@ def run_video(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from base64 encoded image:", result)
print("Chat completion output from base64 encoded video:\n", result)
# Audio input inference
def run_audio(model: str) -> None:
def run_audio(model: str, max_completion_tokens: int) -> None:
from vllm.assets.audio import AudioAsset
audio_url = AudioAsset("winning_call").url
......@@ -211,11 +213,11 @@ def run_audio(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from input audio:", result)
print("Chat completion output from input audio:\n", result)
# HTTP URL
chat_completion_from_url = client.chat.completions.create(
......@@ -235,11 +237,11 @@ def run_audio(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from audio url:", result)
print("Chat completion output from audio url:\n", result)
# base64 URL
chat_completion_from_base64 = client.chat.completions.create(
......@@ -259,14 +261,14 @@ def run_audio(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from base64 encoded audio:", result)
print("Chat completion output from base64 encoded audio:\n", result)
def run_multi_audio(model: str) -> None:
def run_multi_audio(model: str, max_completion_tokens: int) -> None:
from vllm.assets.audio import AudioAsset
# Two different audios to showcase batched inference.
......@@ -300,11 +302,11 @@ def run_multi_audio(model: str) -> None:
}
],
model=model,
max_completion_tokens=64,
max_completion_tokens=max_completion_tokens,
)
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from input audio:", result)
print("Chat completion output from input audio:\n", result)
example_function_map = {
......@@ -330,13 +332,20 @@ def parse_args():
choices=list(example_function_map.keys()),
help="Conversation type with multimodal data.",
)
parser.add_argument(
"--max-completion-tokens",
"-n",
type=int,
default=128,
help="Maximum number of tokens to generate for each completion.",
)
return parser.parse_args()
def main(args) -> None:
chat_type = args.chat_type
model = get_first_model(client)
example_function_map[chat_type](model)
example_function_map[chat_type](model, args.max_completion_tokens)
if __name__ == "__main__":
......
......@@ -161,6 +161,7 @@ def main():
{
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls,
"reasoning": chat_completion.choices[0].message.reasoning,
}
)
......
......@@ -5,7 +5,7 @@ To run this example, you can start the vLLM server
without any specific flags:
```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
vllm serve unsloth/Llama-3.2-1B-Instruct \
--structured-outputs-config.backend outlines
```
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
An example demonstrates how to use tool calling with reasoning models
like QwQ-32B. The reasoning_content will not be parsed by the tool
like QwQ-32B. The reasoning will not be parsed by the tool
calling process; only the final output will be parsed.
To run this example, you need to start the vLLM server with both
......@@ -78,7 +78,7 @@ messages = [
def extract_reasoning_and_calls(chunks: list):
reasoning_content = ""
reasoning = ""
tool_call_idx = -1
arguments = []
function_names = []
......@@ -97,9 +97,9 @@ def extract_reasoning_and_calls(chunks: list):
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
else:
if hasattr(chunk.choices[0].delta, "reasoning_content"):
reasoning_content += chunk.choices[0].delta.reasoning_content
return reasoning_content, arguments, function_names
if hasattr(chunk.choices[0].delta, "reasoning"):
reasoning += chunk.choices[0].delta.reasoning
return reasoning, arguments, function_names
def main():
......@@ -115,7 +115,7 @@ def main():
tool_calls = client.chat.completions.create(
messages=messages, model=model, tools=tools
)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"reasoning: {tool_calls.choices[0].message.reasoning}")
print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
print(
f"function arguments: "
......@@ -129,9 +129,9 @@ def main():
chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)
print(f"reasoning_content: {reasoning_content}")
print(f"reasoning: {reasoning}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")
......@@ -144,7 +144,7 @@ def main():
)
tool_call = tool_calls.choices[0].message.tool_calls[0].function
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"reasoning: {tool_calls.choices[0].message.reasoning}")
print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}")
print("----------Stream Generate With Named Function Calling--------------")
......@@ -159,8 +159,8 @@ def main():
chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
print(f"reasoning_content: {reasoning_content}")
reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)
print(f"reasoning: {reasoning}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")
print("\n\n")
......
......@@ -38,10 +38,10 @@ def main():
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content
reasoning = response.choices[0].message.reasoning
content = response.choices[0].message.content
print("reasoning_content for Round 1:", reasoning_content)
print("reasoning for Round 1:", reasoning)
print("content for Round 1:", content)
# Round 2
......@@ -54,10 +54,10 @@ def main():
)
response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content
reasoning = response.choices[0].message.reasoning
content = response.choices[0].message.content
print("reasoning_content for Round 2:", reasoning_content)
print("reasoning for Round 2:", reasoning)
print("content for Round 2:", content)
......
......@@ -20,7 +20,7 @@ in real-time as they are generated by the model. This is useful for scenarios
where you want to display chat completions to the user as they are generated
by the model.
Remember to check content and reasoning_content exist in `ChatCompletionChunk`,
Remember to check content and reasoning exist in `ChatCompletionChunk`,
content may not exist leading to errors if you try to access it.
"""
......@@ -47,22 +47,20 @@ def main():
stream = client.chat.completions.create(model=model, messages=messages, stream=True)
print("client: Start streaming chat completions...")
printed_reasoning_content = False
printed_reasoning = False
printed_content = False
for chunk in stream:
# Safely extract reasoning_content and content from delta,
# Safely extract reasoning and content from delta,
# defaulting to None if attributes don't exist or are empty strings
reasoning_content = (
getattr(chunk.choices[0].delta, "reasoning_content", None) or None
)
reasoning = getattr(chunk.choices[0].delta, "reasoning", None) or None
content = getattr(chunk.choices[0].delta, "content", None) or None
if reasoning_content is not None:
if not printed_reasoning_content:
printed_reasoning_content = True
print("reasoning_content:", end="", flush=True)
print(reasoning_content, end="", flush=True)
if reasoning is not None:
if not printed_reasoning:
printed_reasoning = True
print("reasoning:", end="", flush=True)
print(reasoning, end="", flush=True)
elif content is not None:
if not printed_content:
printed_content = True
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled.
Reasoning models can be used through the Responses API as seen here
https://platform.openai.com/docs/api-reference/responses
For example:
vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \
--structured-outputs-config.backend xgrammar \
--enable-auto-tool-choice --tool-call-parser hermes
"""
import json
from openai import OpenAI
from utils import get_first_model
def get_weather(latitude: float, longitude: float) -> str:
"""
Mock function to simulate getting weather data.
In a real application, this would call an external weather API.
"""
return f"Current temperature at ({latitude}, {longitude}) is 20°C."
tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get current temperature for provided coordinates in celsius.",
"parameters": {
"type": "object",
"properties": {
"latitude": {"type": "number"},
"longitude": {"type": "number"},
},
"required": ["latitude", "longitude"],
"additionalProperties": False,
},
"strict": True,
}
]
input_messages = [
{"role": "user", "content": "What's the weather like in Paris today?"}
]
def main():
base_url = "http://0.0.0.0:8000/v1"
client = OpenAI(base_url=base_url, api_key="empty")
model = get_first_model(client)
response = client.responses.create(
model=model, input=input_messages, tools=tools, tool_choice="required"
)
for out in response.output:
if out.type == "function_call":
print("Function call:", out.name, out.arguments)
tool_call = out
args = json.loads(tool_call.arguments)
result = get_weather(args["latitude"], args["longitude"])
input_messages.append(tool_call) # append model's function call message
input_messages.append(
{ # append result message
"type": "function_call_output",
"call_id": tool_call.call_id,
"output": str(result),
}
)
response_2 = client.responses.create(
model=model,
input=input_messages,
tools=tools,
)
print(response_2.output_text)
if __name__ == "__main__":
main()
......@@ -3,47 +3,95 @@
## Cohere rerank usage
```bash
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/cohere_rerank_client.py
```
## Embedding requests base64 encoding_format usage
```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_base64_client.py
```
## Embedding requests bytes encoding_format usage
```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_bytes_client.py
```
## Jinaai rerank usage
```bash
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/jinaai_rerank_client.py
```
## Multi vector retrieval usage
```bash
# vllm serve BAAI/bge-m3
python examples/online_serving/pooling/multi_vector_retrieval_client.py
```
## Named Entity Recognition (NER) usage
```bash
python examples/online_serving/pooling/ner.py
# vllm serve boltuix/NeuroBERT-NER
python examples/online_serving/pooling/ner_client.py
```
## Openai chat embedding for multimodal usage
## OpenAI chat embedding for multimodal usage
```bash
python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
```
## Openai classification usage
## OpenAI classification usage
```bash
# vllm serve jason9693/Qwen2.5-1.5B-apeach
python examples/online_serving/pooling/openai_classification_client.py
```
## Openai embedding usage
## OpenAI cross_encoder score usage
```bash
# vllm serve BAAI/bge-reranker-v2-m3
python examples/online_serving/pooling/openai_cross_encoder_score.py
```
## OpenAI cross_encoder score for multimodal usage
```bash
# vllm serve jinaai/jina-reranker-m0
python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py
```
## OpenAI embedding usage
```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/openai_embedding_client.py
```
## Openai embedding matryoshka dimensions usage
## OpenAI embedding matryoshka dimensions usage
```bash
# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py
```
## Openai pooling usage
## OpenAI pooling usage
```bash
# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
python examples/online_serving/pooling/openai_pooling_client.py
```
## Online Prithvi Geospatial MAE usage
```bash
python examples/online_serving/pooling/prithvi_geospatial_mae.py
```
......@@ -8,8 +8,6 @@ Note that `pip install cohere` is needed to run this example.
run: vllm serve BAAI/bge-reranker-base
"""
from typing import Union
import cohere
from cohere import Client, ClientV2
......@@ -25,7 +23,7 @@ documents = [
def cohere_rerank(
client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
client: Client | ClientV2, model: str, query: str, documents: list[str]
) -> dict:
return client.rerank(model=model, query=query, documents=documents)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment