Unverified Commit 4ccffe56 authored by Chenguang Zheng's avatar Chenguang Zheng Committed by GitHub
Browse files

[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)


Signed-off-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: default avatarherotai214 <herotai214@gmail.com>
Signed-off-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Signed-off-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: default avatarn00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: default avatarknlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: default avatarherotai214 <herotai214@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarKhuong Le <khuong.le.manh@huawei.com>
Co-authored-by: default avatarKhuong Le <lemanhkhuong2611@gmail.com>
parent cbb799e3
# Disaggregated Encoder
A **disaggregated encoder** runs the vision-encoder stage of a multimodal LLM in a process that is separate from the pre-fill / decoder stage. Deploying these two stages in independent vLLM instances brings three practical benefits:
1. **Independent, fine-grained scaling**
2. **Lower time-to-first-token (TTFT)**
3. **Cross-process reuse and caching of encoder outputs**
Design doc: <https://docs.google.com/document/d/1aed8KtC6XkXtdoV87pWT0a8OJlZ-CpnuLLzmR8l9BAE>
---
## 1 Motivation
### 1. Independent, fine-grained scaling
* Vision encoders are lightweight, while language models are orders of magnitude larger.
* The language model can be parallelised without affecting the encoder fleet.
* Encoder nodes can be added or removed independently.
### 2. Lower time-to-first-token (TTFT)
* Language-only requests bypass the vision encoder entirely.
* Encoder output is injected only at required attention layers, shortening the pre-fill critical path.
### 3. Cross-process reuse and caching
* In-process encoders confine reuse to a single worker.
* A remote, shared cache lets any worker retrieve existing embeddings, eliminating redundant computation.
---
## 2 Usage Example
The current reference pathway is **SharedStorageConnector**.
Below ready-to-run scripts shows the workflow:
1 Encoder instance + 1 PD instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_encoder_example.sh`
1 Encoder instance + 1 Prefill instance + 1 Decode instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_epd_example.sh`
---
## 3 Test Script
Please refer to the directories `tests/v1/ec_connector`
## 4 Development
Disaggregated encoding is implemented by running two parts:
* **Encoder instance** – a vLLM instance to performs vision encoding.
* **Prefill/Decode (PD) instance(s)** – runs language pre-fill and decode.
* PD can be in either a single normal instance with `disagg_encoder_example.sh` (E->PD) or in disaggregated instances with `disagg_epd_example.sh` (E->P->D)
A connector transfers encoder-cache (EC) embeddings from the encoder instance to the PD instance.
All related code is under `vllm/distributed/ec_transfer`.
### Key abstractions
* **ECConnector** – interface for retrieving EC caches produced by the encoder.
* *Scheduler role* – checks cache existence and schedules loads.
* *Worker role* – loads the embeddings into memory.
Here is a figure illustrating disaggregate encoder flow:
![Disaggregated Encoder Flow](../assets/features/disagg_encoder/disagg_encoder_flow.png)
For the PD disaggregation part, the Prefill instance receive cache exactly the same as the disaggregate encoder flow above. Prefill instance executes 1 step (prefill -> 1 token output) and then transfer KV cache to the Decode instance for the remaining execution. The KV transfer part purely happens after the execute of the PDinstance.
`docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0)
We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D;
# Disaggregated Encoder
These example scripts that demonstrate the disaggregated encoder (EPD) features of vLLM.
For a detailed explanation of the EPD features, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md).
## Files
- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration.
- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image.
- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image.
### Custom Configuration
```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh
# Use specific ports
ENDPOINT_PORT=10001 bash disagg_1e1p1d_example.sh
# Use specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
# Use specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
```
## Encoder Instances
Encoder engines should be launched with the following flags:
- `--enforce-eager` **(required)** – The current EPD implementation is only compatible with encoder instances running in this mode.
- `--no-enable-prefix-caching` **(required)** – Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features.
- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
## Local media inputs
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
```bash
--allowed-local-media-path $MEDIA_PATH
```
The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url": "file://'"$MEDIA_PATH_FILENAME"'}``` as multimodal inputs. Each URI is passed unchanged from the `disagg_encoder_proxy` to the encoder instance so that the encoder can load the media locally.
## EC connector and KV transfer
The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:
```bash
# Add to encoder instance:
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
# Add to prefill/prefill+decode instance:
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
```
`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache.
If you enable prefill instance (`--prefill-servers-urls` not disabled), you will need --kv-transfer-config to facilitate the PD disaggregation. Currently, we use the `NixlConnector` for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example codes on PD disaggregation with Nixl.
```bash
# Add to prefill instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}'
# Add to decode instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}'
```
## Proxy Instance Flags (`disagg_epd_proxy.py`)
| Flag | Description |
|------|-------------|
| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. |
| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). |
| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. |
| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). |
Example usage:
For E + PD setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8002" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://pd1:8003,http://pd2:8004"
```
For E + P + D setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8001" \
--prefill-servers-urls "http://p1:8003,http://p2:8004" \
--decode-servers-urls "http://d1:8005,http://d2:8006"
```
#!/bin/bash
set -euo pipefail
declare -a PIDS=()
###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p $LOG_PATH
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_PORT="${PREFILL_PORT:-19535}"
DECODE_PORT="${DECODE_PORT:-19536}"
PROXY_PORT="${PROXY_PORT:-10001}"
GPU_E="${GPU_E:-2}"
GPU_P="${GPU_P:-2}"
GPU_D="${GPU_D:-3}"
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
export UCX_TLS=all
export UCX_NET_DEVICES=all
###############################################################################
# Helpers
###############################################################################
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
P_LOG=$LOG_PATH/p_${START_TIME}.log
D_LOG=$LOG_PATH/d_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Stopping everything…"
trap - INT TERM USR1 # prevent re-entrancy
# Kill all tracked PIDs
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid" 2>/dev/null
fi
done
# Wait a moment for graceful shutdown
sleep 2
# Force kill any remaining processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -9 "$pid" 2>/dev/null
fi
done
# Kill the entire process group as backup
kill -- -$$ 2>/dev/null
echo "All processes stopped."
exit 0
}
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
# clear previous cache
echo "remove previous ec cache folder"
rm -rf $EC_SHARED_STORAGE_PATH
echo "make ec cache folder"
mkdir -p $EC_SHARED_STORAGE_PATH
###############################################################################
# Encoder worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--gpu-memory-utilization 0.01 \
--port "$ENCODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${ENC_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Prefill worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_P" \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$PREFILL_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}' \
>"${P_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Decode worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_D" \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$DECODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}' \
>"${D_LOG}" 2>&1 &
PIDS+=($!)
# Wait for workers
wait_for_server $ENCODE_PORT
wait_for_server $PREFILL_PORT
wait_for_server $DECODE_PORT
###############################################################################
# Proxy
###############################################################################
python disagg_epd_proxy.py \
--host "0.0.0.0" \
--port "$PROXY_PORT" \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "http://localhost:$PREFILL_PORT" \
--decode-servers-urls "http://localhost:$DECODE_PORT" \
>"${PROXY_LOG}" 2>&1 &
PIDS+=($!)
wait_for_server $PROXY_PORT
echo "All services are up!"
###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
--model $MODEL \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat \
--seed 0 \
--num-prompts $NUM_PROMPTS \
--port $PROXY_PORT
PIDS+=($!)
###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'${MODEL}'",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
{"type": "text", "text": "What is in this image?"}
]}
]
}'
# cleanup
echo "cleanup..."
cleanup
\ No newline at end of file
#!/bin/bash
set -euo pipefail
declare -a PIDS=()
###############################################################################
# Configuration -- override via env before running
###############################################################################
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
LOG_PATH="${LOG_PATH:-./logs}"
mkdir -p $LOG_PATH
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
PROXY_PORT="${PROXY_PORT:-10001}"
GPU_E="${GPU_E:-0}"
GPU_PD="${GPU_PD:-1}"
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
###############################################################################
# Helpers
###############################################################################
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
START_TIME=$(date +"%Y%m%d_%H%M%S")
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Stopping everything…"
trap - INT TERM USR1 # prevent re-entrancy
# Kill all tracked PIDs
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill "$pid" 2>/dev/null
fi
done
# Wait a moment for graceful shutdown
sleep 2
# Force kill any remaining processes
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -9 "$pid" 2>/dev/null
fi
done
# Kill the entire process group as backup
kill -- -$$ 2>/dev/null
echo "All processes stopped."
exit 0
}
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
# clear previous cache
echo "remove previous ec cache folder"
rm -rf $EC_SHARED_STORAGE_PATH
echo "make ec cache folder"
mkdir -p $EC_SHARED_STORAGE_PATH
###############################################################################
# Encoder worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--gpu-memory-utilization 0.01 \
--port "$ENCODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${ENC_LOG}" 2>&1 &
PIDS+=($!)
###############################################################################
# Prefill+Decode worker
###############################################################################
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
--gpu-memory-utilization 0.7 \
--port "$PREFILL_DECODE_PORT" \
--enforce-eager \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
>"${PD_LOG}" 2>&1 &
PIDS+=($!)
# Wait for workers
wait_for_server $ENCODE_PORT
wait_for_server $PREFILL_DECODE_PORT
###############################################################################
# Proxy
###############################################################################
python disagg_epd_proxy.py \
--host "0.0.0.0" \
--port "$PROXY_PORT" \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
>"${PROXY_LOG}" 2>&1 &
PIDS+=($!)
wait_for_server $PROXY_PORT
echo "All services are up!"
###############################################################################
# Benchmark
###############################################################################
echo "Running benchmark (stream)..."
vllm bench serve \
--model $MODEL \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat \
--seed 0 \
--num-prompts $NUM_PROMPTS \
--port $PROXY_PORT
PIDS+=($!)
###############################################################################
# Single request with local image
###############################################################################
echo "Running single request with local image (non-stream)..."
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'${MODEL}'",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
{"type": "text", "text": "What is in this image?"}
]}
]
}'
# cleanup
echo "cleanup..."
cleanup
\ No newline at end of file
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
disagg_encoder_proxy.py
Proxy that routes OpenAI-compatible “/v1/chat/completions” requests to two
clusters:
• encode (multimodal feature extraction)
• decode (language-model inference)
For MM input we:
1. Extract *every* image/audio item.
2. Fire N concurrent requests to the encoder cluster
(one request per item, with **all text removed**).
3. Wait for all of them to succeed.
4. Forward the *original* request to a decode server.
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import random
import uuid
from collections.abc import AsyncIterator
import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
###############################################################################
# FastAPI app & global state
###############################################################################
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger("proxy")
app = FastAPI()
encode_session: aiohttp.ClientSession | None = None
prefill_session: aiohttp.ClientSession | None = None
decode_session: aiohttp.ClientSession | None = None
###############################################################################
# Utils
###############################################################################
MM_TYPES = {"image_url", "audio_url", "input_audio"}
def extract_mm_items(request_data: dict) -> list[dict]:
"""
Return *all* image/audio items that appear anywhere in `messages`.
Each returned dict looks like:
{ "type": "image_url", "image_url": {...} }
"""
items: list[dict] = []
for msg in request_data.get("messages", []):
content = msg.get("content")
if not isinstance(content, list):
continue
for item in content:
if item.get("type") in MM_TYPES:
items.append(item)
return items
async def fanout_encoder_primer(
orig_request: dict,
e_urls: list[str],
req_id: str,
) -> None:
"""
1. Build one request *per MM item* with all text removed.
2. Send them concurrently to the encode cluster.
3. Raise if any of them fails.
"""
logger.info("[%s] Processing multimodal items...", req_id)
mm_items = extract_mm_items(orig_request)
if not mm_items:
logger.info("[%s] No multimodal items, skipping encoder", req_id)
return # nothing to do
logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
tasks = []
# Round-robin over encode servers to distribute load a bit
url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
# Derive a *child* request id: <parent>:<index>:<random-short>
child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
headers = {"x-request-id": child_req_id}
encoder_req = {
# You *may* need to keep additional fields
"model": orig_request.get("model"),
"messages": [
{"role": "user", "content": [item]},
],
# Only need 1 token so the server actually runs the encoder path
"max_tokens": 1,
"stream": False,
}
tasks.append(
encode_session.post(
f"{target_url}/v1/chat/completions",
json=encoder_req,
headers=headers,
)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Fail fast if any sub-request failed
for idx, r in enumerate(results):
if isinstance(r, Exception):
logger.error(
"[%s] Encoder request #%d raised exception: %s",
req_id,
idx,
r,
exc_info=r,
)
raise HTTPException(
status_code=502, detail=f"Encoder request failed: {str(r)}"
)
if r.status != 200:
try:
detail = await r.text()
except Exception:
detail = "<unable to read body>"
logger.error(
"[%s] Encoder request #%d returned status %s: %s",
req_id,
idx,
r.status,
detail,
)
raise HTTPException(
status_code=r.status,
detail=f"Encoder request failed: {detail}",
)
logger.info(
"[%s] All %d encoder requests completed successfully", req_id, len(mm_items)
)
async def maybe_prefill(
req_data: dict,
p_url: str,
req_id: str,
) -> dict:
"""
- Do prefill-only task if p_url exist;
- Return modified request data with kv transfer params (for nixl connector)
- Else, skip and return the original request data for decode
"""
if p_url:
logger.info("[%s] Processing through prefill: %s", req_id, p_url)
prefill_response = await process_prefill_stage(req_data, p_url, req_id)
# for nixl connector to facilitate kv transfer...
prefill_response_json = await prefill_response.json()
kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
return req_data
else:
return req_data
async def process_prefill_stage(
req_data: dict,
p_url: str,
req_id: str,
) -> dict:
"""Process request through Prefill stage and return kv_transfer_params"""
logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
prefill_request = req_data.copy()
prefill_request["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
prefill_request["stream"] = False
prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
if "stream_options" in prefill_request:
del prefill_request["stream_options"]
headers = {"x-request-id": req_id}
try:
prefill_response = await prefill_session.post(
f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
)
prefill_response.raise_for_status()
if prefill_response.status != 200:
error_text = await prefill_response.text()
logger.error(
"[%s] Prefill request failed with status %d: %s",
req_id,
prefill_response.status,
error_text,
)
raise HTTPException(
status_code=prefill_response.status,
detail={"error": "Prefill request failed", "message": error_text},
)
logger.info("[%s] Prefill request completed successfully", req_id)
return prefill_response
except Exception as e:
logger.error("Prefill processing failed: %s", str(e))
raise HTTPException(
status_code=500,
detail={"error": "Prefill processing error", "message": str(e)},
) from e
###############################################################################
# Middleware for request/response logging
###############################################################################
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Middleware to log all incoming requests and responses"""
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
# Log incoming request
logger.info(
">>> [%s] %s %s from %s",
req_id,
request.method,
request.url.path,
request.client.host if request.client else "unknown",
)
try:
# Process request
response = await call_next(request)
# Log response
logger.info(
"<<< [%s] %s %s completed with status %d",
req_id,
request.method,
request.url.path,
response.status_code,
)
return response
except Exception as e:
# Log errors
logger.exception(
"!!! [%s] %s %s failed with error: %s",
req_id,
request.method,
request.url.path,
str(e),
)
raise
###############################################################################
# FastAPI lifecycle
###############################################################################
@app.on_event("startup")
async def on_startup() -> None:
global encode_session, prefill_session, decode_session
timeout = aiohttp.ClientTimeout(total=100_000)
connector = aiohttp.TCPConnector(limit=0, force_close=False)
encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
if app.state.p_urls:
# only setup if prefill instance(s) exist
prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
@app.on_event("shutdown")
async def on_shutdown() -> None:
global encode_session, prefill_session, decode_session
if encode_session:
await encode_session.close()
if prefill_session:
await prefill_session.close()
if decode_session:
await decode_session.close()
###############################################################################
# Core forwarding
###############################################################################
async def forward_non_stream(
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> dict:
try:
# Step 1: Process through Encoder instance (if has MM input)
await fanout_encoder_primer(req_data, e_urls, req_id)
# Step 2: Process through Prefill instance
req_data = await maybe_prefill(req_data, p_url, req_id)
# Step 3: Process through Decode instance
logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Non-streaming response
async with decode_session.post(
f"{d_url}/v1/chat/completions", json=req_data, headers=headers
) as resp:
resp.raise_for_status()
return await resp.json()
except HTTPException:
raise
except Exception as e:
logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
async def forward_stream(
req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> AsyncIterator[str]:
try:
# Step 1: Process through Encoder instance (if has MM input)
await fanout_encoder_primer(req_data, e_urls, req_id)
# Step 2: Process through Prefill instance
req_data = await maybe_prefill(req_data, p_url, req_id)
# Step 3: Process through Decode instance
logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
headers = {"x-request-id": req_id}
# Streaming response
async with decode_session.post(
f"{d_url}/v1/chat/completions",
json=req_data,
headers=headers,
) as resp:
resp.raise_for_status()
async for chunk in resp.content.iter_chunked(1024):
if chunk:
yield chunk.decode("utf-8", errors="ignore")
logger.info("[%s] Streaming completed", req_id)
except HTTPException:
logger.exception("[%s] HTTPException in forward_stream", req_id)
raise
except Exception as e:
logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
raise HTTPException(
status_code=500, detail=f"Proxy streaming error: {str(e)}"
) from e
###############################################################################
# Public routes
###############################################################################
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
try:
req_data = await request.json()
req_id = request.headers.get("x-request-id", str(uuid.uuid4()))
e_urls = app.state.e_urls # we want the full list for fan-out
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
is_streaming = req_data.get("stream", False)
if is_streaming:
return StreamingResponse(
forward_stream(req_data, req_id, e_urls, p_url, d_url),
media_type="text/event-stream",
)
result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url)
return JSONResponse(content=result)
except HTTPException:
raise
except Exception as e:
logger.exception("Error in chat_completions endpoint: %s", str(e))
raise HTTPException(
status_code=500, detail=f"Request processing error: {str(e)}"
) from e
@app.get("/v1/models")
async def list_models():
async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
resp.raise_for_status()
return await resp.json()
@app.get("/health")
async def health_check():
async def healthy(urls):
if not urls:
return "empty"
for u in urls:
try:
async with encode_session.get(f"{u}/health") as resp:
resp.raise_for_status()
except Exception:
return "unhealthy"
return "healthy"
e_status, p_status, d_status = await asyncio.gather(
healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls)
)
overall_healthy = all(
status != "unhealthy" for status in (e_status, p_status, d_status)
)
status_code = 200 if overall_healthy else 503
return JSONResponse(
{
"proxy": "healthy",
"encode_cluster": e_status,
"prefill_cluster": p_status,
"decode_cluster": d_status,
},
status_code=status_code,
)
###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################
async def _post_if_available(
session: aiohttp.ClientSession,
url: str,
payload: dict,
headers: dict,
) -> dict | None:
"""
POST `payload` to `url`.
Returns
-------
• The decoded JSON body on success (2xx)
• None if the endpoint does not exist (404)
• Raises for anything else.
"""
try:
resp = await session.post(url, json=payload, headers=headers)
if resp.status == 404: # profiling disabled on that server
logger.warning("Profiling endpoint missing on %s", url)
return None
resp.raise_for_status()
return await resp.json(content_type=None)
except aiohttp.ClientResponseError as exc:
# Pass 404 through the branch above, re-raise everything else
if exc.status == 404:
logger.warning("Profiling endpoint missing on %s", url)
return None
raise
except Exception:
# Network errors etc.: propagate
raise
async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
"""
Fire & forget to both clusters, tolerate 404.
"""
headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}
encode_task = _post_if_available(
encode_session, f"{e_url}/{cmd}_profile", payload, headers
)
prefill_task = (
_post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
if p_url is not None
else asyncio.sleep(0)
)
decode_task = _post_if_available(
decode_session, f"{d_url}/{cmd}_profile", payload, headers
)
encode_res, prefill_res, decode_res = await asyncio.gather(
encode_task, prefill_task, decode_task
)
# If *all* clusters said “I don’t have that route”, surface an error
if encode_res is prefill_res is decode_res is None:
raise HTTPException(
status_code=503,
detail="Profiling endpoints are disabled on all clusters",
)
return {
"encode": encode_res, # may be None
"prefill": prefill_res, # may be None
"decode": decode_res, # may be None
}
@app.post("/start_profile")
async def start_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("start", body, e_url, p_url, d_url)
@app.post("/stop_profile")
async def stop_profile(request: Request):
body = await request.json()
# TODO: handle multi urls properly
e_url = random.choice(app.state.e_urls)
p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
d_url = random.choice(app.state.d_urls)
return await _profile_cmd("stop", body, e_url, p_url, d_url)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--encode-servers-urls",
required=True,
help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
)
parser.add_argument(
"--prefill-servers-urls",
required=True,
help=(
'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") ',
'to enable E->P->D, set "disable" or "none" to enable E->PD',
),
)
parser.add_argument(
"--decode-servers-urls",
required=True,
help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
)
args = parser.parse_args()
app.state.e_urls = [
u.strip() for u in args.encode_servers_urls.split(",") if u.strip()
]
app.state.d_urls = [
u.strip() for u in args.decode_servers_urls.split(",") if u.strip()
]
# handle prefill instances
if args.prefill_servers_urls.lower() in ("disable", "none", ""):
app.state.p_urls = []
logger.info(
"Disaggregated prefill phase explicitly disabled by user. Running E + PD..."
)
else:
app.state.p_urls = [
u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()
]
logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")
logger.info("Proxy listening on %s:%s", args.host, args.port)
logger.info("Encode servers: %s", app.state.e_urls)
logger.info("Prefill instances %s", app.state.p_urls)
logger.info("Decode servers: %s", app.state.d_urls)
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
loop="uvloop",
access_log=True,
)
......@@ -8,6 +8,7 @@ import torch
from vllm.config import (
CacheConfig,
ECTransferConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
......@@ -20,6 +21,9 @@ from vllm.multimodal.inputs import (
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.utils.hashing import sha256
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
......@@ -872,7 +876,10 @@ def _step_until_done(
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
if scheduler.connector is not None:
assert len(output.kv_connector_metadata.requests) == 0
if scheduler.ec_connector is not None:
assert len(output.ec_connector_metadata.mm_datas) == 0
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
......@@ -1066,7 +1073,10 @@ def test_external_prefix_cache_metrics():
assert external_stats.preempted_requests == 0
def test_kv_connector_unable_to_allocate():
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
"""
Test whether scheduler with KVConnector is able to handle
unable to allocate (run out of blocks in allocate_slots().
......@@ -1080,6 +1090,9 @@ def test_kv_connector_unable_to_allocate():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
......@@ -1148,7 +1161,10 @@ def test_kv_connector_unable_to_allocate():
assert len(scheduler.waiting) == 0
def test_kv_connector_handles_preemption():
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
"""
Test whether scheduler with KVConnector is able to handle
unable to allocate (run out of blocks in allocate_slots().
......@@ -1163,6 +1179,9 @@ def test_kv_connector_handles_preemption():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
......@@ -1379,6 +1398,8 @@ def create_scheduler_with_priority(
block_size: int = 16,
max_model_len: int | None = None,
num_speculative_tokens: int | None = None,
use_ec_connector: bool = False,
ec_role: str | None = None,
) -> Scheduler:
"""Create scheduler with priority policy enabled.
......@@ -1439,12 +1460,23 @@ def create_scheduler_with_priority(
model="ngram", num_speculative_tokens=num_speculative_tokens
)
ec_transfer_config = (
ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
)
if use_ec_connector
else None
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
ec_transfer_config=ec_transfer_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
......@@ -1465,16 +1497,23 @@ def create_scheduler_with_priority(
)
_none_hash_initialized = False
def create_requests_with_priority(
num_requests: int,
priorities: list[int],
arrival_times: list[float] | None = None,
num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
starting_idx: int = 0,
same_prompt: bool = False,
block_size: int = 16,
req_ids: list[str] | None = None,
):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
......@@ -1483,6 +1522,12 @@ def create_requests_with_priority(
else:
arrival_times = [float(i) for i in range(num_requests)]
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(sha256)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=max_tokens,
......@@ -1490,29 +1535,70 @@ def create_requests_with_priority(
prompt_logprobs=prompt_logprobs,
)
requests = []
if mm_hashes_list is not None:
# NOTE: allow manual input; some mm items can have the same identifier
# no. of mm_hashes and mm_positions for each request should be identical
assert mm_positions is not None, (
"mm_positions must be provided when mm_hashes_list is provided"
)
assert len(mm_hashes_list) == len(mm_positions) == num_requests
assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
# Since same identifier would imply they are identical encoder output
# Verify mm items with identical identifier are having mm_position.length
seen_hashes: dict[str, int] = {}
if req_ids:
assert len(req_ids) == num_requests
else:
req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
for j, position in enumerate(
mm_positions[i] if mm_positions is not None else []
):
if mm_hashes_list is not None:
identifier = mm_hashes_list[i][j]
# Verify if position length is identical
position_length = position.length
if identifier in seen_hashes:
assert seen_hashes[identifier] == position_length, (
f"mm_hash '{identifier}' has inconsistent position lengths: "
f"previously {seen_hashes[identifier]}, now {position_length} "
f"at request {i}, position {j}"
)
else:
seen_hashes[identifier] = position_length
else:
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
prompt_token_ids = (
[starting_idx] * num_tokens
if same_prompt
else [i + starting_idx] * num_tokens
)
request = Request(
request_id=f"{i + starting_idx}",
prompt_token_ids=[i + starting_idx] * num_tokens,
request_id=req_ids[i],
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
block_hasher=block_hasher,
)
requests.append(request)
return requests
......@@ -1999,7 +2085,12 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(scheduler.waiting) == 1
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
use_ec_connector, ec_role
):
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
......@@ -2009,6 +2100,9 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
# Create a request and schedule it
......@@ -2168,3 +2262,976 @@ def _validate_chunked_prefill_settings_for_encoder_decoder(
assert scheduler_config.disable_chunked_mm_input is not expect_enabled
if is_encoder_decoder and not expect_enabled:
assert scheduler_config.long_prefill_token_threshold == 0
# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start
# NOTE: In E->P->D disagg case, both KV and EC Connector works in P instance
# Unless specify, the existence of KV Connector should not affect any test results
# ==============================================================================
def _assert_right_encoder_cache_allocated(
scheduler: Scheduler,
hashes_to_check: list[str] | None = None,
requests: list[Request] | None = None,
expected_total_allocated: int | None = None,
):
"""Check whether encoder cache is allocated correctly."""
encoder_cache_manager = scheduler.encoder_cache_manager
# Verify encoder cache manager exists
assert encoder_cache_manager is not None, "Encoder cache manager should exist"
# Verify number of cache
if expected_total_allocated is not None:
assert len(encoder_cache_manager.cached) == expected_total_allocated
if expected_total_allocated == 0:
return
# Verify each request with MM data is in cache
cached_hashes = set(encoder_cache_manager.cached.keys())
if hashes_to_check:
missed_hashes = set(hashes_to_check) - cached_hashes
assert not missed_hashes, (
f"Miss hashes: {missed_hashes} "
f"Existing encoder cache: {encoder_cache_manager.cached}"
)
for req in requests if requests is not None else []:
if req.mm_features:
mm_hashes = [f.identifier for f in req.mm_features]
req_hashes = set(mm_hashes) # unique hashes set
missed_hashes = req_hashes - cached_hashes
assert not missed_hashes, (
f"Miss hashes in cache for request {req.request_id}: {missed_hashes} "
f"Existing encoder cache: {encoder_cache_manager.cached}"
)
def _assert_right_ec_connector_metadata(
output: SchedulerOutput,
mm_features_list: list[MultiModalFeatureSpec],
):
"""Verify that ECConnector metadata EXACTLY matches the input MM data"""
# Get the connector metadata
metadata = output.ec_connector_metadata
# Create lookup dictionaries for efficient access
metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}
# Check all required identifiers exist in metadata; and no extra
# In ECSharedStorageConnector format
# NOTE: even having same identifier, the mm_features can be different
# since their mm_position can be in different offsets, etc
identifiers_dict = {f.identifier for f in mm_features_list}
assert set(metadata_dict.keys()) == identifiers_dict
# Verify the info matches
for i, mm_feature in enumerate(mm_features_list):
identifier = mm_feature.identifier
assert metadata_dict[identifier].mm_hash == identifier
assert metadata_dict[identifier].num_token == mm_feature.mm_position.length
def _assert_right_encoder_inputs(
output: SchedulerOutput,
check_exist: bool | None = True,
requests: list[Request] | None = None,
expected_encoder_inputs: list[list[int]] | None = None,
expected_total_reqs: int | None = None,
):
"""Verify that requests/mm_hashes should (not) in scheduled encoder input
If check_exist is False, this function returns True
if requests are NOT in encoder inputs"""
# Get the scheduled encoder inputs
# NOTE: scheduled_encoder_inputs is a dictionary with request id as key
scheduled_encoder_inputs = output.scheduled_encoder_inputs
# Check if scheduled_encoder_inputs is empty as expected
if expected_total_reqs is not None:
assert len(scheduled_encoder_inputs) == expected_total_reqs
if expected_total_reqs == 0:
return
# Number of expected enocder inputs should match number of requests
if expected_encoder_inputs:
assert check_exist and requests is not None # only support expect input exist
assert len(requests) == len(expected_encoder_inputs)
# Check request (not) exist as expected
for i, request in enumerate(requests if requests is not None else []):
assert (request.request_id in scheduled_encoder_inputs) is check_exist, (
f"Request {request.id} presence mismatch: expected {check_exist}, "
f"got {request.id in scheduled_encoder_inputs}"
)
if expected_encoder_inputs:
scheduled_encoder_input = scheduled_encoder_inputs[request.request_id]
assert scheduled_encoder_input == expected_encoder_inputs[i]
def test_scheduler_no_ec_connector_by_default():
"""Test scheduler doesn't have EC connector by default."""
scheduler = create_scheduler()
assert scheduler.ec_connector is None
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_text_only_request(use_kv_connector):
"""Test text-only requests don't allocate encoder cache."""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
NUM_PROMPT_TOKENS = 100
# Create text-only request (no mm_positions)
requests = create_requests(
num_requests=1,
num_tokens=NUM_PROMPT_TOKENS,
)
assert not requests[0].mm_features # No MM data
scheduler.add_request(requests[0])
output = scheduler.schedule()
# Should schedule
assert len(output.scheduled_new_reqs) == 1
# Scheduled tokens should equal prompt tokens exactly
scheduled = output.num_scheduled_tokens[requests[0].request_id]
assert scheduled == NUM_PROMPT_TOKENS, (
f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}"
)
# Encoder cache should be empty
_assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0)
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_hit_external_load(use_kv_connector):
"""Test ec_consumer loads from external cache when hit.
A normal basic operation for EPD disaggrgation"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
# kv connector should not effect test results
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Create MM request
NUM_TOKENS = 200 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS = 100
mm_hashes_list = [["hash_test1"]]
mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]]
request = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS,
mm_hashes_list=mm_hashes_list,
mm_positions=mm_positions,
)[0]
# Mock cache hit - encoder cache exists externally
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request)
output = scheduler.schedule()
# Should schedule prompt tokens
scheduled_tokens = output.num_scheduled_tokens[request.request_id]
assert scheduled_tokens == NUM_TOKENS
# Should called update_state_after_alloc for external load
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0)
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request])
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
"""Test consumer can compute encoder locally when cache miss (fallback)."""
# encoder cache itself if it doesn't receive it from external storage
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Verify consumer role
assert scheduler.ec_connector is not None
assert not scheduler.ec_connector.is_producer
# Create MM request
request_mm_missed = create_requests(
num_requests=1,
num_tokens=200, # Total (including 100 MM)
mm_positions=[[PlaceholderRange(offset=0, length=100)]], # 100 MM tokens
)[0]
# Mock cache miss - encoder cache doesn't exist externally
scheduler.ec_connector.has_caches = Mock(return_value=[False])
scheduler.add_request(request_mm_missed)
output = scheduler.schedule()
# SCHEDULER should decide to compute encoder locally (fallback)
assert len(output.scheduled_new_reqs) == 1
# Should schedule full prompt tokens
scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id]
assert scheduled_tokens == 200, (
f"Expected 200 tokens on cache miss, got {scheduled_tokens}"
)
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed])
# ECConnector should carry no metadata (missed cache)
_assert_right_ec_connector_metadata(output, mm_features_list=[])
# Scheduled encoder input contain mm for request_mm_missed
_assert_right_encoder_inputs(
output,
requests=[request_mm_missed],
expected_encoder_inputs=[[0]], # index 0 of the mm item
expected_total_reqs=1,
)
# Then MODEL_RUNNER will execute the encoder and cache the result
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
"""Test consumer with partial cache hit (local & connector) with 2 requests."""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Create MM request
NUM_TOKENS_1 = 300 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS_1 = 50
mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]]
mm_positions_1 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1),
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
]
]
# Create request with 4 MM items, with 2 identical items
request1 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_1,
mm_hashes_list=mm_hashes_list_1,
mm_positions=mm_positions_1,
max_tokens=1, # For simplicity
)[0]
# Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request1)
output = scheduler.schedule()
# Should schedule all tokens
scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
assert scheduled_tokens == NUM_TOKENS_1
# Encoder cache should contain all mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request1])
# Should have called update_state_after_alloc for external load
scheduler.ec_connector.update_state_after_alloc.assert_called()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for 2nd and 4th mm item
_assert_right_ec_connector_metadata(
output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]]
)
# Should schedule ONLY 1 encoder input (index 0), no repeat for identical items
_assert_right_encoder_inputs(
output,
requests=[request1],
expected_encoder_inputs=[[0]], # index 0 of the mm item ONLY
expected_total_reqs=1,
)
# Simulate model execution 1 step
model_output = ModelRunnerOutput(
req_ids=[request1.request_id],
req_id_to_index={request1.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# request1 is finished after outputing 1 token
# Finish request
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
# Create another request with 4 MM items
NUM_TOKENS_2 = 400
NUM_ENCODER_TOKENS_2 = 50
mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]]
mm_positions_2 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
]
]
request2 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_2,
mm_hashes_list=mm_hashes_list_2,
mm_positions=mm_positions_2,
max_tokens=1, # For simplicity
)[0]
# Mock partial cache hit: only hash1_A and hash1_C exist in connector
scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True])
scheduler.add_request(request2)
output = scheduler.schedule()
# Check
# Should schedule all tokens
scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
assert scheduled_tokens == 400
# Encoder cache should contain all mm items from request2
_assert_right_encoder_cache_allocated(scheduler, requests=[request2])
# Should call update_state_after_alloc for hash1_C, ONLY
# hash1_A should not be loaded from connector
# since it's computed in last request & exist in local cache
# Order of getting encoder cache should be: local cache -> connector-> compute
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0)
scheduler.ec_connector.update_state_after_alloc.assert_called_once()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for hash1_C only (index 0)
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[0]]
)
# Should schedule 2 encoder input hash1_D and hash1_E (index 1, 2)
_assert_right_encoder_inputs(
output,
requests=[request2],
expected_encoder_inputs=[[1, 2]],
expected_total_reqs=1,
)
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_seqs=10, # allow multiple requests
max_num_batched_tokens=2048,
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
use_ec_connector=True,
ec_role="ec_consumer",
)
mm_hashes_list = [[f"hash_{i}"] for i in range(10)]
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
requests = create_requests(
num_requests=10,
num_tokens=200,
mm_hashes_list=mm_hashes_list,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
# Set up to test different encoder cache exsistence scenario after preemption
# Order of getting encoder cache should be: local cache -> connector-> compute
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
if cache_exist == "local":
# Allocate cache to cache manager manually to mimick
for req in requests:
scheduler.encoder_cache_manager.allocate(req, 0)
else:
# Make sure local encoder cache empty
scheduler.encoder_cache_manager.cached = {}
if cache_exist == "connector_only":
# Cache exist in ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True])
elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False])
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
## Encoder-cache-specific checks:
# mm_hashes of requests exist in cache after scheduling for all scenario
_assert_right_encoder_cache_allocated(scheduler, requests=requests)
# Should only call update_state_after_alloc when loaded externally
if cache_exist == "connector_only":
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
requests[-1], 0
)
# Concat mm_features for the 10 requests together
mm_features_list = [feature for req in requests for feature in req.mm_features]
# Check metadata should contain mm data for all 10 requests
_assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list)
else:
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# Should only schedule encoder input when cache is not found anywhere
if cache_exist == "no_where":
_assert_right_encoder_inputs(
output,
requests=requests,
expected_encoder_inputs=[[0] for _ in range(10)],
expected_total_reqs=10,
)
else:
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_unable_to_allocate(use_kv_connector):
"""
Test whether scheduler with ECConnector is able to handle
unable to allocate (run out of blocks).
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
use_ec_connector=True,
ec_role="ec_consumer",
)
# Mock ec_connector load external cache behavior
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
NUM_REQUESTS = 2
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
mm_hashes_list=[["hash_1"], ["hash_2"]],
mm_positions=[
[PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS)
],
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Setup MODEL_RUNNER_OUTPUT to be run in _step_until_done later
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
output = scheduler.schedule()
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
assert scheduled_tokens == NUM_TOKENS
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Should have called update_state_after_alloc for external load
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
scheduler.running[0], 0
)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# All memory should be freed, with one request waiting.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Just one request should be running.
output = scheduler.schedule()
scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id]
assert scheduled_tokens == NUM_TOKENS
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# update_state_after_alloc should be called for loading external cache
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
scheduler.running[0], 0
)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# All memory should be freed, with no requests waiting / running.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"])
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_priority_scheduling_ec_connector_preemption_and_resumption(
cache_exist, use_kv_connector
):
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
model="llava-hf/llava-1.5-7b-hf",
enable_prefix_caching=True,
max_num_seqs=2, # allow multiple requests
# kv connector should not effect test results
use_kv_connector=use_kv_connector,
num_blocks=15, # can hold 244 tokens with 14 blocks (first block is null)
block_size=16, # standard block size
use_ec_connector=True,
ec_role="ec_consumer",
)
# Mock cache hit: Both cache exist in connector (at E->PD initially)
scheduler.ec_connector.has_caches = Mock(return_value=[True])
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
# Create a request and schedule it (and to be preempted)
request_low = create_requests_with_priority(
num_requests=1,
priorities=[1],
arrival_times=[0.0],
num_tokens=94,
mm_hashes_list=[["hash_low"]],
# NOTE: this test only preempt the last block.
# Setting mm_position at the last block can force to recompute encoding
mm_positions=[[PlaceholderRange(offset=82, length=10)]],
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
scheduled_tokens = output.num_scheduled_tokens[request_low.request_id]
assert scheduled_tokens == 94
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
## Encoder-cache-specific checks:
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
# Verify update_state_after_alloc called (external load)
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(
output, mm_features_list=request_low.mm_features
)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Create a high priority request and schedule it
request_high = create_requests_with_priority(
num_requests=1,
priorities=[0],
arrival_times=[1.0],
num_tokens=128,
mm_hashes_list=[["hash_high"]],
mm_positions=[[PlaceholderRange(offset=1, length=10)]],
max_tokens=2,
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
## Encoder-cache-specific checks:
# Encoder cache should contain mm items from request
_assert_right_encoder_cache_allocated(scheduler, requests=[request_high])
# Verify update_state_after_alloc called (external load)
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of request
_assert_right_ec_connector_metadata(
output, mm_features_list=request_high.mm_features
)
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# 3rd schedule - - this should trigger preemption
# req_low needs 96 tokens = 6 blocks
# req_high needs 129 tokens = 9 blocks
# so doesn't fit in 14 blocks.
output = scheduler.schedule()
# Should have preempted req_low
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
## Encoder-cache-specific checks:
# request_high is in decode phase now
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
# Scheduled encoder input should be empty; no mm to compute
_assert_right_encoder_inputs(output, expected_total_reqs=0)
# Simulate model execution - 3rd decode, after req_low was preempted
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[100], [100, 200]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Finish the requests to make room for the preempted requests to resume
# req_high is finished after outputing 2 tokens
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(
request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED
)
# Set up to test different encoder cache exsistence scenario after preemption
# Order of getting encoder cache should be: local cache -> connector-> compute
# By default, the cache should still exist in local in this test case
if cache_exist != "local":
# Make local encoder cache empty
scheduler.encoder_cache_manager.cached = {}
if cache_exist == "connector_only":
# Cache exist in ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True])
elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False])
# 4th Schedule - this should trigger req_low resumption from waiting
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
## Resumed tokens include 94 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 96
assert scheduled_cached_reqs.resumed_req_token_ids[0][95] == 100
assert scheduler.running[0].request_id == request_low.request_id
assert request_high.request_id in output.finished_req_ids
## Encoder-cache-specific checks:
# mm_hash of request_low exists in cache after scheduling for all scenario
_assert_right_encoder_cache_allocated(scheduler, requests=[request_low])
# Should only call update_state_after_alloc when loaded externally
if cache_exist == "connector_only":
scheduler.ec_connector.update_state_after_alloc.assert_called_with(
request_low, 0
)
_assert_right_ec_connector_metadata(
output, mm_features_list=request_low.mm_features
)
else:
scheduler.ec_connector.update_state_after_alloc.assert_not_called()
# ECConnector should carry no metadata
_assert_right_ec_connector_metadata(output, mm_features_list=[])
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# Should only schedule encoder input when cache is not found anywhere
if cache_exist == "no_where":
_assert_right_encoder_inputs(
output,
requests=[request_low],
expected_encoder_inputs=[[0]],
expected_total_reqs=1,
)
else:
_assert_right_encoder_inputs(output, expected_total_reqs=0)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
"""
Scenario:
- Encoder cache size: 32
- Request A: 1 feature (12 tokens) → NOT cached remotely.
- Request B: 3 features (3 x 10 tokens) → ALL cached remotely.
Steps:
1. Schedule Request A (locally uses 12 tokens).
2. Schedule Request B (remote cache) - only schedule 1st and 2nd
3. Free A's cache, then schedule B again (continuation) - schedule 3rd image
"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
enable_prefix_caching=True,
use_kv_connector=use_kv_connector,
block_size=16,
num_blocks=11, # Can hold 160 tokens (first block is null)
use_ec_connector=True,
ec_role="ec_consumer",
)
# Limit the number of availiable slots of EncoderCacheManager
scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32)
# Create MM request1
NUM_TOKENS_1 = 50 # NOTE: includes mm tokens
NUM_ENCODER_TOKENS_1 = 12
mm_hashes_list_1 = [["hash1_1"]]
mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]]
request1 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_1,
mm_hashes_list=mm_hashes_list_1,
mm_positions=mm_positions_1,
max_tokens=1, # For simplicity
req_ids=["req1"],
)[0]
# Create MM request1 with 3 MM items
NUM_TOKENS_2 = 40
NUM_ENCODER_TOKENS_2 = 10
mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]]
mm_positions_2 = [
[
PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2),
PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2),
]
]
request2 = create_requests(
num_requests=1,
num_tokens=NUM_TOKENS_2,
mm_hashes_list=mm_hashes_list_2,
mm_positions=mm_positions_2,
max_tokens=10,
req_ids=["req2"],
)[0]
# Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
scheduler.ec_connector.has_caches = Mock(
side_effect=lambda req: [True, True, True] if req == request2 else [False]
)
scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc
)
scheduler.add_request(request1)
scheduler.add_request(request2)
output = scheduler.schedule()
# Now, since encoder cache manager can only store 32 tokens
# It should allocated mm item hash1_1, hash2_1 and hash2_2
scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
assert scheduled_tokens == NUM_TOKENS_1
assert scheduler.get_num_unfinished_requests() == 2
# Encoder cache should contain mm item from request1
_assert_right_encoder_cache_allocated(
scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"]
)
# request2's 2nd mm item is the last call of update_state_after_alloc
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1)
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata of hash2_1 and hash2_2 ONLY
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]]
)
# Should schedule ONLY 1 encoder input
_assert_right_encoder_inputs(
output,
requests=[request1],
expected_encoder_inputs=[[0]], # index 0 of the mm item of request1
expected_total_reqs=1,
)
# Simulate model execution 1 step
model_output = ModelRunnerOutput(
req_ids=[request1.request_id, request2.request_id],
req_id_to_index={request1.request_id: 0, request2.request_id: 1},
sampled_token_ids=[[100], [121]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# request1 is finished after outputing 1 token
# Finish request
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
assert scheduler.get_num_unfinished_requests() == 1
# Schedule again; Now request1's encoder cache should be freed
# -> hash2_3 can be scheduled and allocated
output = scheduler.schedule()
# Check
# Should schedule all tokens
scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
print(f"Hero: scheduled_tokens for req2: {scheduled_tokens}")
print(f"hero: num_scheduled_tokens 2: {output.num_scheduled_tokens}")
# Encoder cache should contain all mm items from request2
_assert_right_encoder_cache_allocated(scheduler, requests=[request2])
# request2's 3rd mm item is the ONLY call of update_state_after_alloc
scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2)
scheduler.ec_connector.update_state_after_alloc.assert_called_once()
scheduler.ec_connector.update_state_after_alloc.reset_mock()
# ECConnector should carry metadata for hash2_3 ONLY
_assert_right_ec_connector_metadata(
output, mm_features_list=[request2.mm_features[2]]
)
# Should schedule no encoder input
_assert_right_encoder_inputs(
output,
expected_total_reqs=0,
)
# ==============================================================================
# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end
# ==============================================================================
......@@ -5,6 +5,7 @@ import torch
from vllm.config import (
CacheConfig,
ECTransferConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
......@@ -46,6 +47,8 @@ def create_scheduler(
num_speculative_tokens: int | None = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
use_ec_connector: bool = False,
ec_role: str | None = None,
) -> Scheduler | AsyncScheduler:
"""Create scheduler under test.
......@@ -107,12 +110,23 @@ def create_scheduler(
model="ngram", num_speculative_tokens=num_speculative_tokens
)
ec_transfer_config = (
ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
)
if use_ec_connector
else None
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
ec_transfer_config=ec_transfer_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
......@@ -140,12 +154,14 @@ _none_hash_initialized = False
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
same_prompt: bool = False,
block_size: int = 16,
req_ids: list[str] | None = None,
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
......@@ -160,25 +176,58 @@ def create_requests(
prompt_logprobs=prompt_logprobs,
)
requests = []
if mm_hashes_list is not None:
# NOTE: allow manual input; some mm items can have the same identifier
# no. of mm_hashes and mm_positions for each request should be identical
assert mm_positions is not None, (
"mm_positions must be provided when mm_hashes_list is provided"
)
assert len(mm_hashes_list) == len(mm_positions) == num_requests
assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
# Since same identifier would imply they are identical encoder output
# Verify mm items with identical identifier are having mm_position.length
seen_hashes: dict[str, int] = {}
if req_ids:
assert len(req_ids) == num_requests
else:
req_ids = [f"{i}" for i in range(num_requests)]
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
for j, position in enumerate(
mm_positions[i] if mm_positions is not None else []
):
if mm_hashes_list is not None:
identifier = mm_hashes_list[i][j]
# Verify if position length is identical
position_length = position.length
if identifier in seen_hashes:
assert seen_hashes[identifier] == position_length, (
f"mm_hash '{identifier}' has inconsistent position lengths: "
f"previously {seen_hashes[identifier]}, now {position_length} "
f"at request {i}, position {j}"
)
else:
seen_hashes[identifier] = position_length
else:
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
request = Request(
request_id=f"{i}",
request_id=req_ids[i],
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
......
# EPD Correctness Test
This test verifies that EPD (Encoder-Prefill-Decode) disaggregation produces identical outputs to a baseline single instance.
## What It Tests
- **Baseline**: Single vLLM instance serving a multimodal model
- **EPD (1E+1PD)**: 1 Encoder + 1 Prefill-Decode instance
- **Baseline (1P+1D)**: 1 Prefill + 1 Decode instance
- **EPD (1E+1P+1D)**: 1 Encoder + 1 Prefill + 1 Decode instance
The test ensures that disaggregated encoding produces **identical** outputs to the baseline.
Note that currently PD disaggregation set up may give slightly different results from a single instance. Therefore, we need the result from 1P+1D as the baseline for 1E+1P+1D
Please refer to [Disaggregated Encoder Feature](../../../docs/features/disagg_encoder.md) for the detailed explanation for the EPD features.
## Files
- `run_epd_correctness_test.sh` - Main test script (starts all instances and runs tests)
- `test_epd_correctness.py` - Python test script (compares outputs)
## Usage
### Multimodal Prompts (Default)
```bash
cd vllm
./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```
This runs the test with actual multimodal (image) prompts.
### Text-Only Prompts
```bash
cd vllm
USE_MM_PROMPTS=0 ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```
This runs a quick test with text-only prompts to verify the setup works.
### Custom Configuration
```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
# Use specific ports
ENDPOINT_PORT=10001 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
# Use specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
# Use specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```
## How It Works
### Step 1: Baseline
1. Start single vLLM instance on GPU
2. Run test prompts (multimodal or text-only)
3. Save outputs to `.vllm_epd_baseline.txt`
4. Shutdown instance
### Step 2: EPD (1E + 1PD)
1. Clear encoder cache storage
2. Start instances and proxy
3. Run same test prompts
4. Assert outputs match baseline exactly
5. Shutdown instances
### Step 3: EPD (1E + 1P + 1D)
1. Clear encoder cache storage
2. Start instances and proxy
3. Run same test prompts
4. Assert outputs match baseline exactly
5. Shutdown instances
## Test Scenarios
### Multimodal Prompts (--use_mm_prompts)
Tests encoder cache transfer:
- Single image query
- Multiple images in one request
- Mixed image and text
- Image with detailed questions
### Text-Only Prompts (default)
Quick sanity check:
- Simple text queries
- Text-only explanations
- Verifies proxy routing works
## Expected Behavior
### ✅ Test Passes When
- All disagg outputs match baseline outputs exactly
- No errors during instance startup
- Encoder cache is properly saved and loaded
- Proxy correctly routes requests
### ❌ Test Fails When
- Outputs differ between baseline and disagg
- Server startup fails
- Encoder cache not found (should fallback to local execution)
- Proxy routing errors
## Notes
- The test uses deterministic generation (`temperature=0.0`, `seed=42`)
- Encoder cache should enable exact output reproduction
- Test cleans up all instances and cache files after completion
- Safe to run multiple times (idempotent)
- We setup the PD disagg part with NixlConnector. Please read details about EPD in `examples/online_serving/disaggregated_encoder/README.md`
## Requirements
- Multiple GPUs (3 for 1E+1P+1D, 2 for 1E+1PD, 1 for baseline)
- 1E+1P+1D is runnable with 2 GPU by assign E and P on the same GPU now.
- Multimodal model (e.g., Qwen2.5-VL-3B-Instruct)
- Internet access (for accessing vllm test images)
## Debugging
### Check Logs
Logs and baseline output are saved in `/tmp/` by default.
Can be customized by changing the environment variables.
### Check Encoder Cache
```bash
# Verify cache files are created
ls -la $EC_SHARED_STORAGE_PATH/
# Should see directories with mm_hash names
# Each containing encoder_cache.safetensors
```
### Manual Testing
Run individual components:
```bash
# Baseline only
python test_epd_correctness.py \
--service_url http://localhost:8000 \
--model_name Qwen/Qwen2.5-VL-3B-Instruct \
--mode baseline \
--baseline_file test_output.txt \
--use_mm_prompts
# Disagg only (requires baseline output file!)
python test_epd_correctness.py \
--service_url http://localhost:8000 \
--model_name Qwen/Qwen2.5-VL-3B-Instruct \
--mode disagg \
--baseline_file test_output.txt \
--use_mm_prompts
```
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# EPD (Encoder-Prefill-Decode) Correctness Test
#
# This script tests that EPD disaggregation produces the same outputs as baseline.
# It runs:
# 1. Baseline: Single vLLM instance
# 2. EPD: 1E + 1PD setup
# 3. Baseline for (E + P + D): 1P + 1D vLLM instances disagg
# 4. EPD: 1E + 1P + 1D setup
# For GPU usage
# set -xe
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Model to test
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
# Set 1 to use multimodal prompts; else to use text-only
USE_MM_PROMPTS="${USE_MM_PROMPTS:-1}"
MM_FLAG=""
if [ $USE_MM_PROMPTS = "1" ]; then
MM_FLAG="--use_mm_prompts"
fi
# GPU configuration
GPU_E="${GPU_E:-0}"
GPU_P="${GPU_P:-1}"
GPU_D="${GPU_D:-2}"
GPU_SINGLE="${GPU_SINGLE:-$GPU_P}"
GPU_PD="${GPU_PD:-$GPU_P}"
# Port
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_PORT="${PREFILL_PORT:-19535}"
DECODE_PORT="${DECODE_PORT:-19536}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19537}"
ENDPOINT_PORT="${ENDPOINT_PORT:-10001}"
# Storage path for encoder cache
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache_test}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-600}"
# Output file for baseline comparison and logs
LOG_PATH="${LOG_PATH:-/tmp}"
BASELINE_FILE="${BASELINE_FILE:-/tmp/vllm_baseline.txt}"
BASELINE_PD_FILE="${BASELINE_PD_FILE:-/tmp/vllm_epd_baseline.txt}"
mkdir -p $LOG_PATH
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Wait for server to be ready
wait_for_server() {
local port=$1
timeout "$TIMEOUT_SECONDS" bash -c "
until curl -s localhost:${port}/v1/chat/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
pkill -f "disagg_epd_proxy.py" || true
sleep 2
}
# Function to run baseline (single instance)
run_baseline() {
echo "================================"
echo "Running BASELINE (single instance)"
echo "================================"
cleanup_instances
rm -rf "$EC_SHARED_STORAGE_PATH"
local PORT=$ENDPOINT_PORT
# Start baseline instance
echo "Starting baseline instance on GPU $GPU_SINGLE, port $PORT"
CUDA_VISIBLE_DEVICES="$GPU_SINGLE" vllm serve "$MODEL" \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
> $LOG_PATH/baseline.log 2>&1 &
local BASELINE_PID=$!
# Wait for baseline to start
echo "Waiting for baseline instance to start..."
wait_for_server $PORT
curl http://127.0.0.1:$PORT/v1/models
echo ""
# Run test in baseline mode
echo "Running baseline..."
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
--service_url "http://localhost:$PORT" \
--model_name "$MODEL" \
--mode baseline \
--baseline_file "$BASELINE_FILE" \
$MM_FLAG
# Cleanup baseline
echo "Stopping baseline instance..."
kill $BASELINE_PID 2>/dev/null || true
sleep 2
cleanup_instances
}
# Function to run EPD with 1E + 1PD
run_epd_1e_1pd() {
echo "================================"
echo "Running EPD (1E + 1PD)"
echo "================================"
cleanup_instances
rm -rf "$EC_SHARED_STORAGE_PATH"
mkdir -p "$EC_SHARED_STORAGE_PATH"
local ENCODE_PORT=$ENCODE_PORT
local PREFILL_DECODE_PORT=$PREFILL_DECODE_PORT
local PROXY_PORT=$ENDPOINT_PORT
declare -a PIDS=()
# Start encoder instance
echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--port $ENCODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.01 \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
> $LOG_PATH/1e1pd_encoder.log 2>&1 &
PIDS+=($!)
# Start prefill+decode instance
echo "Starting PD instance on GPU $GPU_PD, port $PREFILL_DECODE_PORT"
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
--port $PREFILL_DECODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
> $LOG_PATH/1e1pd_pd.log 2>&1 &
PIDS+=($!)
# Wait for instances to start
echo "Waiting for encoder instance..."
wait_for_server $ENCODE_PORT
echo "Waiting for PD instance..."
wait_for_server $PREFILL_DECODE_PORT
# Start proxy
echo "Starting EPD proxy on port $PROXY_PORT"
python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
--host "0.0.0.0" \
--port $PROXY_PORT \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
> $LOG_PATH/1e1pd_proxy.log 2>&1 &
PIDS+=($!)
# Wait for proxy
echo "Waiting for proxy..."
wait_for_server $PROXY_PORT
curl http://127.0.0.1:$PROXY_PORT/v1/models
curl http://127.0.0.1:$PROXY_PORT/health
echo ""
echo "All EPD (1E+1PD) services are up!"
# Run test in disagg mode
echo "Running EPD (1E+1PD) correctness test..."
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
--service_url "http://localhost:$PROXY_PORT" \
--model_name "$MODEL" \
--mode disagg \
--baseline_file "$BASELINE_FILE" \
$MM_FLAG
# Cleanup
echo "✓✓ 1E+1PD Correctness Test finished"
echo "Stopping EPD (1E+1PD) instances..."
for pid in "${PIDS[@]}"; do
kill $pid 2>/dev/null || true
done
sleep 2
cleanup_instances
}
# Function to run baseline for 1E + 1P + 1D (PD disagg)
run_baseline_1p_1d() {
echo "================================"
echo "Running PD BASELINE (1P + 1D)"
echo "================================"
cleanup_instances
rm -rf "$EC_SHARED_STORAGE_PATH"
mkdir -p "$EC_SHARED_STORAGE_PATH"
local PREFILL_PORT=$PREFILL_PORT
local DECODE_PORT=$DECODE_PORT
local PROXY_PORT=$ENDPOINT_PORT
declare -a PIDS=()
# Start prefill instance
echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
CUDA_VISIBLE_DEVICES="$GPU_P" \
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
vllm serve "$MODEL" \
--port $PREFILL_PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}' \
> $LOG_PATH/1p1d_prefill.log 2>&1 &
PIDS+=($!)
# Start decode instance
echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
CUDA_VISIBLE_DEVICES="$GPU_D" \
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
vllm serve "$MODEL" \
--port $DECODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}' \
> $LOG_PATH/1p1d_decode.log 2>&1 &
PIDS+=($!)
# Wait for instances to start
echo "Waiting for prefill instance..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance..."
wait_for_server $DECODE_PORT
# Start proxy
echo "Starting EPD proxy on port $PROXY_PORT"
python "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
--host "0.0.0.0" \
--port $PROXY_PORT \
--prefiller-ports $PREFILL_PORT \
--decoder-ports $DECODE_PORT \
> $LOG_PATH/1p1d_proxy.log 2>&1 &
PIDS+=($!)
# Wait for proxy
echo "Waiting for proxy..."
wait_for_server $PROXY_PORT
curl http://127.0.0.1:$PROXY_PORT/healthcheck
echo ""
echo "All PD (1P+1D) services are up!"
# Run test in baseline mode
echo "Running PD disagg baseline..."
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
--service_url "http://localhost:$PROXY_PORT" \
--model_name "$MODEL" \
--mode baseline_pd \
--baseline_file "$BASELINE_PD_FILE" \
$MM_FLAG
# Cleanup
echo "Stopping PD (1P+1D) instances..."
for pid in "${PIDS[@]}"; do
kill $pid 2>/dev/null || true
done
sleep 2
cleanup_instances
}
# Function to run EPD with 1E + 1P + 1D
run_epd_1e_1p_1d() {
echo "================================"
echo "Running EPD (1E + 1P + 1D)"
echo "================================"
cleanup_instances
rm -rf "$EC_SHARED_STORAGE_PATH"
mkdir -p "$EC_SHARED_STORAGE_PATH"
local ENCODE_PORT=$ENCODE_PORT
local PREFILL_PORT=$PREFILL_PORT
local DECODE_PORT=$DECODE_PORT
local PROXY_PORT=$ENDPOINT_PORT
declare -a PIDS=()
# Start encoder instance
echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
--port $ENCODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.01 \
--enable-request-id-headers \
--no-enable-prefix-caching \
--max-num-batched-tokens 114688 \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
> $LOG_PATH/1e1p1d_encoder.log 2>&1 &
PIDS+=($!)
# Start prefill instance
echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
CUDA_VISIBLE_DEVICES="$GPU_P" \
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
vllm serve "$MODEL" \
--port $PREFILL_PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}' \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}' \
> $LOG_PATH/1e1p1d_prefill.log 2>&1 &
PIDS+=($!)
# Start decode instance
echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
CUDA_VISIBLE_DEVICES="$GPU_D" \
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
vllm serve "$MODEL" \
--port $DECODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.7 \
--enable-request-id-headers \
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}' \
> $LOG_PATH/1e1p1d_decode.log 2>&1 &
PIDS+=($!)
# Wait for instances to start
echo "Waiting for encoder instance..."
wait_for_server $ENCODE_PORT
echo "Waiting for prefill instance..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance..."
wait_for_server $DECODE_PORT
# Start proxy
echo "Starting EPD proxy on port $PROXY_PORT"
python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
--host "0.0.0.0" \
--port $PROXY_PORT \
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
--prefill-servers-urls "http://localhost:$PREFILL_PORT" \
--decode-servers-urls "http://localhost:$DECODE_PORT" \
> $LOG_PATH/1e1p1d_proxy.log 2>&1 &
PIDS+=($!)
# Wait for proxy
echo "Waiting for proxy..."
wait_for_server $PROXY_PORT
curl http://127.0.0.1:$PROXY_PORT/v1/models
curl http://127.0.0.1:$PROXY_PORT/health
echo ""
echo "All EPD (1E+1P+1D) services are up!"
# Run test in disagg mode
echo "Running EPD (1E+1P+1D) correctness test..."
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
--service_url "http://localhost:$PROXY_PORT" \
--model_name "$MODEL" \
--mode disagg \
--baseline_file "$BASELINE_PD_FILE" \
$MM_FLAG
# Cleanup
echo "✓✓ 1E+1P+1D Correctness Test finished"
echo "Stopping EPD (1E+1P+1D) instances..."
for pid in "${PIDS[@]}"; do
kill $pid 2>/dev/null || true
done
sleep 2
cleanup_instances
}
# Main execution
echo "================================"
echo "EPD Correctness Test Suite"
echo "Model: $MODEL"
echo "================================"
# Step 1: Run baseline
run_baseline
# Step 2: Test 1E + 1PD
run_epd_1e_1pd
# Step 3: Test baseline 1P + 1D
run_baseline_1p_1d
# Step 4: Test 1E + 1P + 1D
run_epd_1e_1p_1d
# Cleanup output file
rm -f "$BASELINE_FILE"
rm -f "$BASELINE_PD_FILE"
echo "================================"
echo "✓✓ All EPD correctness tests finished!"
echo "================================"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
EPD Correctness Test
Tests that EPD (Encoder-Prefill-Decode) disaggregation produces the same
outputs as a baseline single instance.
Usage:
# Baseline mode (saves outputs):
python test_epd_correctness.py \
--service_url http://localhost:8000 \
--model_name Qwen/Qwen2.5-VL-3B-Instruct \
--mode baseline \
--baseline_file .vllm_epd_baseline.txt
# Disagg mode (compares outputs):
python test_epd_correctness.py \
--service_url http://localhost:8000 \
--model_name Qwen/Qwen2.5-VL-3B-Instruct \
--mode disagg \
--baseline_file .vllm_epd_baseline.txt
"""
import argparse
import json
import os
import time
import openai
import requests
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_base64
MAX_OUTPUT_LEN = 256
# Sample prompts with multimodal content
image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))
image_local_path = f"{os.path.dirname(os.path.abspath(__file__))}/hato.jpg"
SAMPLE_PROMPTS_MM: list[dict] = [
{
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_1)}"
},
},
{"type": "text", "text": "What's in this image?"},
],
}
],
"description": "Single image query",
},
{
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_2)}"
},
},
{
"type": "image_url",
"image_url": {"url": f"file://{image_local_path}"},
},
{"type": "text", "text": "Describe these 2 images in detail."},
],
}
],
"description": "2 images with detailed query",
},
]
# Text-only prompts for mixed testing
SAMPLE_PROMPTS_TEXT: list[dict] = [
{
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"description": "Simple text-only query",
},
{
"messages": [
{"role": "user", "content": "Explain quantum computing in simple terms."}
],
"description": "Text-only explanation request",
},
]
def check_vllm_server(url: str, timeout=5, retries=10) -> bool:
"""Check if the vLLM server is ready.
Args:
url: The URL to check (usually /health or /healthcheck endpoint)
timeout: Timeout in seconds for each request
retries: Number of retries if the server is not ready
Returns:
True if the server is ready, False otherwise
"""
for attempt in range(retries):
try:
response = requests.get(url, timeout=timeout)
if response.status_code == 200:
print(f"Server is ready at {url}")
return True
else:
print(
f"Attempt {attempt + 1}/{retries}: Server returned "
f"status code {response.status_code}"
)
except requests.exceptions.RequestException as e:
print(f"Attempt {attempt + 1}/{retries}: Error connecting: {e}")
time.sleep(2) # Wait before retrying
return False
def run_chat_completion(
base_url: str,
model_name: str,
messages: list,
max_tokens: int = MAX_OUTPUT_LEN,
) -> str:
"""Run a chat completion request.
Args:
base_url: Base URL of the vLLM server
model_name: Name of the model
messages: Messages for chat completion
max_tokens: Maximum tokens to generate
Returns:
Generated text content
"""
client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
completion = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=0.0,
seed=42,
)
return completion.choices[0].message.content
def main():
"""Main test function."""
parser = argparse.ArgumentParser(
description="EPD correctness test - compare disagg vs baseline"
)
parser.add_argument(
"--service_url",
type=str,
required=True,
help="The vLLM service URL (e.g., http://localhost:8000)",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Model name",
)
parser.add_argument(
"--mode",
type=str,
default="baseline",
choices=["baseline", "baseline_pd", "disagg"],
help="Mode: baseline/baseline_pd (saves outputs) or disagg (compares outputs)",
)
parser.add_argument(
"--baseline_file",
type=str,
default=".vllm_epd_baseline.txt",
help="File to save/load baseline outputs",
)
parser.add_argument(
"--use_mm_prompts",
action="store_true",
help="Use multimodal prompts (default: use text-only for quick testing)",
)
args = parser.parse_args()
print(f"Service URL: {args.service_url}")
print(f"Model: {args.model_name}")
print(f"Mode: {args.mode}")
print(f"Output file: {args.baseline_file}")
print(f"Use MM prompts: {args.use_mm_prompts}")
# Determine health check endpoint
if args.mode == "baseline":
health_check_url = f"{args.service_url}/health"
elif args.mode == "baseline_pd":
# Nixl toy proxy use /healthcheck
health_check_url = f"{args.service_url}/healthcheck"
else:
# Disagg EPD proxy uses /health
health_check_url = f"{args.service_url}/health"
if not os.path.exists(args.baseline_file):
raise ValueError(
f"In disagg mode, the output file {args.baseline_file} from "
"baseline does not exist. Run baseline mode first."
)
# Check if server is ready
if not check_vllm_server(health_check_url):
raise RuntimeError(f"vLLM server at {args.service_url} is not ready!")
# Select prompts to use
if args.use_mm_prompts:
test_prompts = SAMPLE_PROMPTS_MM
print("Using multimodal prompts")
else:
test_prompts = SAMPLE_PROMPTS_TEXT
print("Using text-only prompts for quick testing")
# Run completions
service_url = f"{args.service_url}/v1"
output_strs = {}
for i, prompt_data in enumerate(test_prompts):
print(
f"\nRunning prompt {i + 1}/{len(test_prompts)}: {
prompt_data['description']
}"
)
output_str = run_chat_completion(
base_url=service_url,
model_name=args.model_name,
messages=prompt_data["messages"],
max_tokens=MAX_OUTPUT_LEN,
)
# Use description as key for comparison
key = prompt_data["description"]
output_strs[key] = output_str
print(f"Output: {output_str}")
if args.mode in ("baseline", "baseline_pd"):
# Baseline mode: Save outputs
print(f"\nSaving baseline outputs to {args.baseline_file}")
try:
with open(args.baseline_file, "w") as json_file:
json.dump(output_strs, json_file, indent=4)
print("✅ Baseline outputs saved successfully")
except OSError as e:
print(f"Error writing to file: {e}")
raise
else:
# Disagg mode: Load and compare outputs
print(f"\nLoading baseline outputs from {args.baseline_file}")
baseline_outputs = None
try:
with open(args.baseline_file) as json_file:
baseline_outputs = json.load(json_file)
except OSError as e:
print(f"Error reading from file: {e}")
raise
# Verify outputs match
print("\nComparing disagg outputs with baseline...")
assert isinstance(baseline_outputs, dict), "Baseline outputs should be a dict"
assert len(baseline_outputs) == len(output_strs), (
f"Length mismatch: baseline has {len(baseline_outputs)}, "
f"disagg has {len(output_strs)}"
)
all_match = True
for key, baseline_output in baseline_outputs.items():
assert key in output_strs, f"{key} not in disagg outputs"
disagg_output = output_strs[key]
if baseline_output == disagg_output:
print(f"✅ {key}: MATCH")
else:
print(f"❌ {key}: MISMATCH")
print(f" Baseline: {baseline_output}")
print(f" Disagg: {disagg_output}")
all_match = False
assert all_match, "❌❌Disagg outputs do not match baseline!❌❌"
if all_match:
print("\n✅ All outputs match! Test PASSED")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for ECSharedStorageConnector.
"""
import os
from unittest.mock import Mock, patch
import pytest
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
ECSharedStorageConnector,
ECSharedStorageConnectorMetadata,
MMMeta,
)
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.sched.output import SchedulerOutput
# ------------------ Mock Classes ------------------ #
class MockRequest:
def __init__(self, request_id, mm_hashes: list[str], token_counts: list[int]):
assert len(mm_hashes) == len(token_counts)
self.request_id = request_id
self._token_counts = token_counts
self.mm_features = []
for i, mm_hash in enumerate(mm_hashes):
feature = MultiModalFeatureSpec(
data=None,
modality="image",
identifier=mm_hash,
mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]),
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self._token_counts)
return self._token_counts[input_id]
@pytest.fixture
def temp_storage(tmp_path):
"""Fixture providing temporary storage path."""
return str(tmp_path)
@pytest.fixture
def mock_vllm_config_producer(temp_storage):
"""Fixture providing mock VllmConfig for producer role."""
config = Mock(spec=VllmConfig)
config.ec_transfer_config = Mock()
config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
config.ec_transfer_config.is_ec_producer = True
return config
@pytest.fixture
def mock_vllm_config_consumer(temp_storage):
"""Fixture providing mock VllmConfig for consumer role."""
config = Mock(spec=VllmConfig)
config.ec_transfer_config = Mock()
config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
config.ec_transfer_config.is_ec_producer = False
return config
@pytest.fixture
def mock_request_with_3_mm():
"""Fixture providing mock Request with 3 multimodal items."""
request_id = "test_req_123"
mm_hashes = ["img_hash_1", "img_hash_2", "img_hash_3"]
token_counts = [100, 150, 200]
request = MockRequest(request_id, mm_hashes, token_counts)
return request
# ------------------ Unit Tests ------------------ #
class TestECSharedStorageConnectorBasics:
"""Test basic EC connector functionality."""
def test_initialization_producer(self, mock_vllm_config_producer, temp_storage):
"""Test connector initializes correctly as producer."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
assert connector.role == ECConnectorRole.SCHEDULER
assert connector.is_producer
assert connector._storage_path == temp_storage
assert connector._mm_datas_need_loads == {}
def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage):
"""Test connector initializes correctly as consumer."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
assert connector.role == ECConnectorRole.WORKER
assert not connector.is_producer
assert connector._storage_path == temp_storage
def test_role_assignment(self, mock_vllm_config_producer):
"""Test role is correctly assigned."""
scheduler_connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
worker_connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
assert scheduler_connector.role == ECConnectorRole.SCHEDULER
assert worker_connector.role == ECConnectorRole.WORKER
class TestCacheExistence:
"""Test cache existence checking using has_caches() API."""
def test_has_caches_all_exist_3_items(
self,
mock_vllm_config_producer,
mock_vllm_config_consumer,
mock_request_with_3_mm,
):
"""Test has_caches returns True when all 3 caches exist."""
# Test for producer first
producer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Create cache files using save_caches (proper way)
encoder_cache: dict[str, torch.Tensor] = {}
for mm_feature in mock_request_with_3_mm.mm_features:
mm_hash = mm_feature.identifier
encoder_cache[mm_hash] = torch.randn(10, 768)
producer.save_caches(encoder_cache, mm_hash)
# Test using has_caches API
producer_result = producer.has_caches(mock_request_with_3_mm)
# Assert
assert len(producer_result) == 3
assert all(producer_result), f"Expected all True, got {producer_result}"
# Also test consumer can check if cache exists
consumer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.SCHEDULER,
)
# Test using has_caches API
consumer_result = consumer.has_caches(mock_request_with_3_mm)
# Assert
assert len(consumer_result) == 3
assert all(consumer_result), f"Expected all True, got {consumer_result}"
def test_has_caches_none_exist(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test has_caches returns False when no caches exist."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Test without creating any files
result = connector.has_caches(mock_request_with_3_mm)
# Assert
assert len(result) == 3
assert not any(result), f"Expected all False, got {result}"
def test_has_caches_partial_exist(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test has_caches with some caches existing (1 of 3)."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Create only the second cache file
mm_hash_second = mock_request_with_3_mm.mm_features[1].identifier
encoder_cache = {mm_hash_second: torch.randn(10, 768)}
connector.save_caches(encoder_cache, mm_hash_second)
# Test
result = connector.has_caches(mock_request_with_3_mm)
# Assert
assert len(result) == 3
assert not result[0] # First doesn't exist
assert result[1] # Second exists
assert not result[2] # Third doesn't exist
class TestStateManagement:
"""Test connector state management."""
def test_update_state_after_alloc_3_items(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test state update after allocation for 3 MM items."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Initial state should be empty
assert len(connector._mm_datas_need_loads) == 0
# Update state for all 3 items
for i in range(3):
connector.update_state_after_alloc(mock_request_with_3_mm, index=i)
# Check state updated for all 3
assert len(connector._mm_datas_need_loads) == 3
assert "img_hash_1" in connector._mm_datas_need_loads
assert "img_hash_2" in connector._mm_datas_need_loads
assert "img_hash_3" in connector._mm_datas_need_loads
assert connector._mm_datas_need_loads["img_hash_1"] == 100
assert connector._mm_datas_need_loads["img_hash_2"] == 150
assert connector._mm_datas_need_loads["img_hash_3"] == 200
def test_build_connector_meta_3_items(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test metadata building for 3 MM items."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Setup state for all 3 items
for i in range(3):
connector.update_state_after_alloc(mock_request_with_3_mm, index=i)
# Build metadata
scheduler_output = Mock(spec=SchedulerOutput)
metadata = connector.build_connector_meta(scheduler_output)
# Assert
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert len(metadata.mm_datas) == 3
assert metadata.mm_datas[0].mm_hash == "img_hash_1"
assert metadata.mm_datas[0].num_token == 100
assert metadata.mm_datas[1].mm_hash == "img_hash_2"
assert metadata.mm_datas[1].num_token == 150
assert metadata.mm_datas[2].mm_hash == "img_hash_3"
assert metadata.mm_datas[2].num_token == 200
# State should be cleared after building
assert len(connector._mm_datas_need_loads) == 0
def test_build_connector_meta_empty(self, mock_vllm_config_producer):
"""Test metadata building with empty state."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
scheduler_output = Mock(spec=SchedulerOutput)
metadata = connector.build_connector_meta(scheduler_output)
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert len(metadata.mm_datas) == 0
def test_state_cleared_after_metadata_build(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test that state is properly cleared after building metadata."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
# Add state
for i in range(3):
connector.update_state_after_alloc(mock_request_with_3_mm, index=i)
assert len(connector._mm_datas_need_loads) == 3
# Build metadata (should clear state)
scheduler_output = Mock(spec=SchedulerOutput)
connector.build_connector_meta(scheduler_output)
# State should be empty
assert len(connector._mm_datas_need_loads) == 0
# Build again should return empty metadata
metadata2 = connector.build_connector_meta(scheduler_output)
assert len(metadata2.mm_datas) == 0
class TestCacheSaving:
"""Test encoder cache saving (producer only)."""
def test_save_caches_producer_3_items(
self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage
):
"""Test cache saving as producer for 3 different MM items."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
# Create and save 3 different caches
mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
encoder_cache: dict[str, torch.Tensor] = {}
for mm_hash in mm_hashes:
encoder_cache[mm_hash] = torch.randn(10, 768)
connector.save_caches(encoder_cache, mm_hash)
# Verify all files exist using has_caches
result = connector.has_caches(mock_request_with_3_mm)
assert all(result), f"Not all caches were saved: {result}"
# Verify each file's content
for mm_hash in mm_hashes:
filename = connector._generate_filename_debug(mm_hash)
loaded = safetensors.torch.load_file(filename)
assert "ec_cache" in loaded
assert torch.allclose(loaded["ec_cache"], encoder_cache[mm_hash].cpu())
def test_save_caches_consumer_skips(self, mock_vllm_config_consumer):
"""Test cache saving is skipped for consumer."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
mm_hash = "test_hash_consumer"
encoder_cache = {mm_hash: torch.randn(10, 768)}
# Save should not raise but also not create file
connector.save_caches(encoder_cache, mm_hash)
# Verify file doesn't exist using has_caches
mock_request = MockRequest("req_consumer", [mm_hash], [10])
result = connector.has_caches(mock_request)
assert not result[0], "Consumer should not save caches"
class TestCacheLoading:
"""Test encoder cache loading (consumer)."""
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_start_load_caches_consumer_3_items(
self,
mock_vllm_config_producer,
mock_vllm_config_consumer,
mock_request_with_3_mm,
temp_storage,
):
"""Test consumer loads 3 caches from storage."""
# First, create producer to save caches
producer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
# Producer saves 3 caches
mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
saved_caches = {}
for mm_hash in mm_hashes:
saved_caches[mm_hash] = torch.randn(10, 768)
producer.save_caches(saved_caches, mm_hash)
# Now consumer loads
consumer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
# Setup metadata for all 3
metadata = ECSharedStorageConnectorMetadata()
for mm_hash in mm_hashes:
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata)
# Load
encoder_cache: dict[str, torch.Tensor] = {}
consumer.start_load_caches(encoder_cache=encoder_cache)
# Verify all 3 loaded
assert len(encoder_cache) == 3
for mm_hash in mm_hashes:
assert mm_hash in encoder_cache, f"{mm_hash} missing in encoder_cache"
assert encoder_cache[mm_hash].is_cuda, (
f"{mm_hash} cache is in {encoder_cache[mm_hash].device}"
)
assert torch.allclose(
encoder_cache[mm_hash].cpu(), saved_caches[mm_hash]
), f"{mm_hash} cache saved and loaded tesnor are not the same"
def test_start_load_caches_skip_existing(
self, mock_vllm_config_producer, mock_vllm_config_consumer, temp_storage
):
"""Test cache loading skips already cached items."""
# Setup: producer saves cache
producer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
mm_hash = "existing_hash"
saved_cache = torch.randn(10, 768)
producer.save_caches({mm_hash: saved_cache}, mm_hash)
# Consumer setup
consumer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata)
# Pre-populate encoder_cache with different value
existing_cache = torch.randn(5, 512)
encoder_cache = {mm_hash: existing_cache}
# Load (should skip since already exists)
with patch("safetensors.torch.load_file") as mock_load:
consumer.start_load_caches(encoder_cache=encoder_cache)
# Should not call load_file since cache exists
mock_load.assert_not_called()
# Verify original cache unchanged
assert torch.equal(encoder_cache[mm_hash], existing_cache)
def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer):
"""Test loading with empty metadata does nothing."""
consumer = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
# Setup empty metadata
metadata = ECSharedStorageConnectorMetadata()
consumer.bind_connector_metadata(metadata)
# Load (should not raise)
encoder_cache: dict[str, torch.Tensor] = {}
consumer.start_load_caches(encoder_cache=encoder_cache)
# Cache should remain empty
assert len(encoder_cache) == 0
class TestFilenameGeneration:
"""Test filename and path generation."""
def test_generate_foldername(self, mock_vllm_config_producer, temp_storage):
"""Test folder name generation."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
mm_hash = "test_folder_hash"
folder = connector._generate_foldername_debug(mm_hash)
assert folder == os.path.join(temp_storage, mm_hash)
assert os.path.isdir(folder) # Should be created
def test_generate_filename(self, mock_vllm_config_producer, temp_storage):
"""Test filename generation."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
mm_hash = "test_file_hash"
filename = connector._generate_filename_debug(mm_hash)
expected = os.path.join(temp_storage, mm_hash, "encoder_cache.safetensors")
assert filename == expected
assert os.path.isdir(os.path.dirname(filename)) # Folder created
def test_generate_filename_consistency(self, mock_vllm_config_producer):
"""Test filename generation is consistent."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
mm_hash = "consistency_hash"
filename1 = connector._generate_filename_debug(mm_hash)
filename2 = connector._generate_filename_debug(mm_hash)
assert filename1 == filename2
class TestMetadataBindingLifecycle:
"""Test metadata binding and clearing lifecycle."""
def test_bind_connector_metadata(self, mock_vllm_config_consumer):
"""Test binding connector metadata."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("hash_1", 100))
connector.bind_connector_metadata(metadata)
assert connector._connector_metadata is metadata
def test_clear_connector_metadata(self, mock_vllm_config_consumer):
"""Test clearing connector metadata."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
connector.bind_connector_metadata(metadata)
connector.clear_connector_metadata()
assert connector._connector_metadata is None
def test_get_connector_metadata(self, mock_vllm_config_consumer):
"""Test getting connector metadata."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
connector.bind_connector_metadata(metadata)
retrieved = connector._get_connector_metadata()
assert retrieved is metadata
def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer):
"""Test getting metadata when not set raises."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
with pytest.raises(AssertionError):
connector._get_connector_metadata()
class TestEdgeCases:
"""Test edge cases and error handling."""
def test_save_empty_cache(self, mock_vllm_config_producer):
"""Test saving empty tensor."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
mm_hash = "empty_hash"
encoder_cache = {mm_hash: torch.empty(0)}
# Should not raise
connector.save_caches(encoder_cache, mm_hash)
def test_load_nonexistent_cache(self, mock_vllm_config_consumer):
"""Test loading cache that doesn't exist raises error."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100))
connector.bind_connector_metadata(metadata)
encoder_cache: dict[str, torch.Tensor] = {}
# Should raise FileNotFoundError
with pytest.raises(FileNotFoundError):
connector.start_load_caches(encoder_cache=encoder_cache)
def test_has_caches_empty_request(self, mock_vllm_config_producer):
"""Test has_caches with request that has no MM data."""
connector = ECSharedStorageConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
mock_request = MockRequest("req_empty", [], [])
result = connector.has_caches(mock_request)
assert len(result) == 0
assert result == []
......@@ -10,6 +10,14 @@ import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.config import (
CacheConfig,
ECTransferConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_num_threads
......@@ -450,3 +458,141 @@ def test_engine_core_invalid_request_id_type():
engine_core.add_request(*engine_core.preprocess_add_request(valid_request))
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
@create_new_process_for_each_test()
@pytest.mark.parametrize(
("ec_role", "gpu_memory_utilization", "enable_prefix_caching"),
[
("ec_producer", 0.01, False),
# NOTE: ec_producer never allows prefix caching
("ec_consumer", 0.7, True),
("ec_consumer", 0.7, False),
],
)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_encoder_instance_zero_kv_cache(
ec_role: str,
gpu_memory_utilization: float,
enable_prefix_caching: bool,
use_kv_connector: bool,
):
"""EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests
This test verifies encoder-only instance initializes with 0 KV cache blocks.
Under EPD disagg mode, Encoder instances (EC producer role) only execute
vision encoder, so they don't need KV cache for text generation.
"""
# Form vllm config
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
disable_hybrid_kv_cache_manager=True,
)
model_config = ModelConfig(
model="llava-hf/llava-1.5-7b-hf", # Multimodal model
enforce_eager=True,
trust_remote_code=True,
dtype="float16",
seed=42,
)
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=enable_prefix_caching,
)
kv_transfer_config = (
KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
)
if use_kv_connector
else None
)
ec_transfer_config = ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"},
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
kv_transfer_config=kv_transfer_config,
ec_transfer_config=ec_transfer_config,
)
executor_class = Executor.get_class(vllm_config)
print(f"executor_class: {executor_class}")
with set_default_torch_num_threads(1):
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
)
# Check encoder cache manager exists
assert engine_core.scheduler.encoder_cache_manager is not None, (
"encoder_cache_manager should exist"
)
if ec_role == "ec_producer":
# Check 1: num_blocks should be 0
# NOTE: num_blocks=1 as BlockPool always needs a null_block.
kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
print(f"kv_cache_config: {kv_cache_config}")
assert kv_cache_config.num_blocks == 1, (
f"ec_producer should only have 1 KV blocks, "
f"got {kv_cache_config.num_blocks}"
)
# Check 2: kv_cache_groups should be empty
assert len(kv_cache_config.kv_cache_groups) == 0, (
f"ec_producer should have 0 KV cache groups, "
f"got {len(kv_cache_config.kv_cache_groups)}"
)
# Check 3: kv_cache_tensors should be empty
assert len(kv_cache_config.kv_cache_tensors) == 0, (
f"Encoder instance should have 0 KV cache tensors, "
f"got {len(kv_cache_config.kv_cache_tensors)}"
)
# Check 4: Verify EC connector is initialized and is producer
assert engine_core.scheduler.ec_connector is not None, (
"Encoder instance should have EC connector"
)
assert engine_core.scheduler.ec_connector.is_producer, (
"Encoder instance EC connector should be producer"
)
# Check 5: Verify chunked prefill is disabled
assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
"Encoder instance should disable chunked prefill (no KV cache)"
)
elif ec_role == "ec_consumer":
# Check 1: num_blocks should be > 1
kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
print(f"kv_cache_config: {kv_cache_config}")
assert kv_cache_config.num_blocks > 1, (
f"ec_consumer should have >1 KV blocks, got {kv_cache_config.num_blocks}"
)
# Check 2: kv_cache_groups should NOT be empty
assert len(kv_cache_config.kv_cache_groups) > 0, (
f"ec_consumer should have KV cache groups, "
f"got {len(kv_cache_config.kv_cache_groups)}"
)
# Check 3: Verify EC connector is consumer
assert engine_core.scheduler.ec_connector is not None, (
"Consumer instance should have EC connector"
)
assert not engine_core.scheduler.ec_connector.is_producer, (
"Consumer instance EC connector should be consumer"
)
......@@ -9,6 +9,7 @@ from vllm.config.compilation import (
PassConfig,
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
......@@ -54,6 +55,8 @@ __all__ = [
"PassConfig",
# From vllm.config.device
"DeviceConfig",
# From vllm.config.ec_transfer
"ECTransferConfig",
# From vllm.config.kv_events
"KVEventsConfig",
# From vllm.config.kv_transfer
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
ECProducer = Literal["ec_producer"]
ECConsumer = Literal["ec_consumer"]
ECRole = Literal[ECProducer, ECConsumer]
@config
@dataclass
class ECTransferConfig:
"""Configuration for distributed EC cache transfer."""
ec_connector: str | None = None
"""The EC connector for vLLM to transmit EC caches between vLLM instances.
"""
engine_id: str | None = None
"""The engine id for EC transfers."""
ec_buffer_device: str | None = "cuda"
"""The device used by ec connector to buffer the EC cache.
Currently only support 'cuda'."""
ec_buffer_size: float = 1e9
"""The buffer size for TorchDistributedConnector. Measured in number of
bytes. Recommended value: 1e9 (about 1GB)."""
ec_role: ECRole | None = None
"""Whether this vLLM instance produces, consumes EC cache, or both. Choices
are 'ec_producer', 'ec_consumer'."""
ec_rank: int | None = None
"""The rank of this vLLM instance in the EC cache transfer. Typical value:
0 for encoder, 1 for pd instance.
Currently only 1P1D is supported."""
ec_parallel_size: int = 1
"""The number of parallel instances for EC cache transfer. For
PyNcclConnector, this should be 2."""
ec_ip: str = "127.0.0.1"
"""The EC connector ip, used to build distributed connection."""
ec_port: int = 14579
"""The EC connector port, used to build distributed connection."""
ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
"""any extra config that the connector may need."""
ec_connector_module_path: str | None = None
"""The Python module path to dynamically load the EC connector from.
Only supported in V1."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.ec_role is not None and self.ec_role not in get_args(ECRole):
raise ValueError(
f"Unsupported ec_role: {self.ec_role}. "
f"Supported roles are {get_args(ECRole)}"
)
if self.ec_connector is not None and self.ec_role is None:
raise ValueError(
"Please specify ec_role when ec_connector "
f"is set, supported roles are {get_args(ECRole)}"
)
@property
def is_ec_transfer_instance(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECRole)
@property
def is_ec_producer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECProducer)
@property
def is_ec_consumer(self) -> bool:
return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)
def get_from_extra_config(self, key, default) -> Any:
return self.ec_connector_extra_config.get(key, default)
......@@ -28,6 +28,7 @@ from vllm.utils import random_uuid
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
......@@ -103,6 +104,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
"""The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
......@@ -183,6 +186,10 @@ class VllmConfig:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.ec_transfer_config:
vllm_factors.append(self.ec_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.ec_transfer.ec_transfer_state import (
ensure_ec_transfer_initialized,
get_ec_transfer,
has_ec_transfer,
)
__all__ = [
"get_ec_transfer",
"ensure_ec_transfer_initialized",
"has_ec_transfer",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ECConnectorBase Class for Distributed Encoder Cache &
P2P Encoder cache communication in V1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save Encoder cache.
check_caches_exist() - Check whether Encoder cache of requests exist
update_state_after_alloc() - update ECConnector state after
allocate. This will decide to load the cache or not
request_finished() - called when a request is finished,
free the cache with the requests
Worker-side: runs in each worker, loads/saves Encoder Cache to/from
the Connector based on the metadata.
start_load_ec() - starts loading all ECs (maybe async)
wait_for_save() - blocks until all saves are done
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ECConnectorOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class ECConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
class ECConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler ECConnector and Worker ECConnector.
"""
pass
class ECConnectorBase(ABC):
def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
self._connector_metadata: ECConnectorMetadata | None = None
self._vllm_config = vllm_config
self._role = role
if vllm_config.ec_transfer_config is not None:
self._is_producer = vllm_config.ec_transfer_config.is_ec_producer
else:
raise ValueError("ec_transfer_config must be set for ECConnectorBase")
@property
def role(self) -> ECConnectorRole:
return self._role
@property
def is_producer(self) -> bool:
return self._is_producer
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
EC cache loading.
Args:
connector_metadata (dict): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = None
def _get_connector_metadata(self) -> ECConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_caches(
self,
ec_caches: dict[str, torch.Tensor],
):
"""
Initialize with the EC caches.
Args:
ec_caches: dictionary of encoder cache
"""
# TODO: Implement this later for P2P feature
return
@abstractmethod
def start_load_caches(
self, encoder_cache: dict[str, torch.Tensor], **kwargs
) -> None:
"""
Start loading the cache from the connector into vLLM's encoder cache.
This method loads the encoder cache based on metadata provided by the scheduler.
It is called before `_gather_mm_embeddings` for the EC Connector. For EC,
the `encoder_cache` and `mm_hash` are stored in `kwargs`.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
pass
@abstractmethod
def save_caches(
self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
) -> None:
"""
Save the encoder cache to the connector.
This method saves the encoder cache from the worker's local storage
to shared storage or another external connector.
Args:
encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
data hashes (`mm_hash`) to encoder cache tensors.
mm_hash (str): The hash of the multimodal data whose cache is being saved.
kwargs (dict): Additional keyword arguments for the connector.
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def has_caches(
self,
request: "Request",
) -> list[bool]:
"""
Check if encoder cache exists for each mm data of requests
Args:
request (Request): the request object.
Returns:
A list bool where ith value is True if cache exist for
ith mm_data of requests
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request", index: int):
"""
Update ECConnector state to decide allocate cache for requests
Args:
request (Request): the request object.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> ECConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def update_connector_output(self, connector_output: ECConnectorOutput):
"""
Update ECConnector state from worker-side connectors output.
Args:
connector_output (ECConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self, request: "Request"
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its encoder cache is freed.
Returns:
True if the request is being saved/sent asynchronously and cached
should not be freed until the request_id is returned from
get_finished().
"""
return False, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment