Unverified Commit e5bfcb6a authored by Pan Li's avatar Pan Li Committed by GitHub
Browse files

[BugFix][PD]: make example proxy usable with P2pNcclConnector (#26628)


Signed-off-by: default avatarPAN <1162953505@qq.com>
parent 22924383
...@@ -5,11 +5,12 @@ import argparse ...@@ -5,11 +5,12 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
import time
import uuid
from urllib.parse import urlparse
import aiohttp import aiohttp
from quart import Quart, Response, make_response, request from quart import Quart, Response, make_response, request
from rate_limiter import RateLimiter
from request_queue import RequestQueue
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -24,26 +25,8 @@ def parse_args(): ...@@ -24,26 +25,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--timeout", "--timeout",
type=float, type=float,
default=300, default=6 * 60 * 60,
help="Timeout for backend service requests in seconds (default: 300)", help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=100,
help="Maximum concurrent requests to backend services (default: 100)",
)
parser.add_argument(
"--queue-size",
type=int,
default=500,
help="Maximum number of requests in the queue (default: 500)",
)
parser.add_argument(
"--rate-limit",
type=int,
default=40,
help="Maximum requests per second (default: 40)",
) )
parser.add_argument( parser.add_argument(
"--port", "--port",
...@@ -54,14 +37,32 @@ def parse_args(): ...@@ -54,14 +37,32 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--prefill-url", "--prefill-url",
type=str, type=str,
default="http://localhost:8100/v1/completions", default="http://localhost:8100",
help="Prefill service endpoint URL", help="Prefill service base URL (protocol + host[:port])",
) )
parser.add_argument( parser.add_argument(
"--decode-url", "--decode-url",
type=str, type=str,
default="http://localhost:8200/v1/completions", default="http://localhost:8200",
help="Decode service endpoint URL", help="Decode service base URL (protocol + host[:port])",
)
parser.add_argument(
"--kv-host",
type=str,
default="localhost",
help="Hostname or IP used by KV transfer (default: localhost)",
)
parser.add_argument(
"--prefill-kv-port",
type=int,
default=14579,
help="Prefill KV port (default: 14579)",
)
parser.add_argument(
"--decode-kv-port",
type=int,
default=14580,
help="Decode KV port (default: 14580)",
) )
return parser.parse_args() return parser.parse_args()
...@@ -73,70 +74,129 @@ def main(): ...@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters # Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
MAX_CONCURRENT_REQUESTS = args.max_concurrent
REQUEST_QUEUE_SIZE = args.queue_size
RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url DECODE_SERVICE_URL = args.decode_url
PORT = args.port PORT = args.port
app = Quart(__name__) PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
logger.info(
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
PREFILL_KV_ADDR,
DECODE_KV_ADDR,
)
# Initialize the rate limiter and request queue app = Quart(__name__)
rate_limiter = RateLimiter(RATE_LIMIT)
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
# Attach the configuration object to the application instance # Attach the configuration object to the application instance so helper
# coroutines can read the resolved backend URLs and timeouts without using
# globals.
app.config.update( app.config.update(
{ {
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"rate_limiter": rate_limiter,
"request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL, "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
"DECODE_KV_ADDR": DECODE_KV_ADDR,
} }
) )
# Start queue processing on app startup def _normalize_base_url(url: str) -> str:
@app.before_serving """Remove any trailing slash so path joins behave predictably."""
async def startup(): return url.rstrip("/")
"""Start request processing task when app starts serving"""
asyncio.create_task(request_queue.process()) def _get_host_port(url: str) -> str:
"""Return the hostname:port portion for logging and KV headers."""
parsed = urlparse(url)
host = parsed.hostname or "localhost"
port = parsed.port
if port is None:
port = 80 if parsed.scheme == "http" else 443
return f"{host}:{port}"
PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
async def forward_request(url, data): def _build_headers(request_id: str) -> dict[str, str]:
"""Forward request to backend service with rate limiting and error handling""" """Construct the headers expected by vLLM's P2P disagg connector."""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
# Use rate limiter as context manager async def _run_prefill(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{PREFILL_BASE}{request_path}"
start_ts = time.perf_counter()
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
try:
async with ( async with (
rate_limiter,
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
raise RuntimeError(
f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
async def _stream_decode(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
): ):
url = f"{DECODE_BASE}{request_path}"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger.info("[decode] start request_id=%s url=%s", request_id, url)
try: try:
async with session.post( async with (
url=url, json=data, headers=headers aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
) as response: session.post(url=url, json=payload, headers=headers) as resp,
if response.status == 200: ):
# Stream response chunks if resp.status != 200:
async for chunk_bytes in response.content.iter_chunked(1024): error_text = await resp.text()
yield chunk_bytes
else:
# Handle backend service errors
error_text = await response.text()
logger.error( logger.error(
"Backend service error: %s - %s", "Decode backend error %s - %s", resp.status, error_text
response.status, )
error_text, err_msg = (
) '{"error": "Decode backend error ' + str(resp.status) + '"}'
yield b'{"error": "Backend service error"}' )
except aiohttp.ClientError as e: yield err_msg.encode()
# Handle connection errors return
logger.error("Connection error to %s: %s", url, str(e)) logger.info(
yield b'{"error": "Service unavailable"}' "[decode] streaming response request_id=%s status=%s",
request_id,
resp.status,
)
async for chunk_bytes in resp.content.iter_chunked(1024):
yield chunk_bytes
logger.info("[decode] finished streaming request_id=%s", request_id)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Handle timeout errors logger.error("Decode service timeout at %s", url)
logger.error("Timeout connecting to %s", url) yield b'{"error": "Decode service timeout"}'
yield b'{"error": "Service timeout"}' except aiohttp.ClientError as exc:
logger.error("Decode service error at %s: %s", url, exc)
yield b'{"error": "Decode service unavailable"}'
async def process_request(): async def process_request():
"""Process a single request through prefill and decode stages""" """Process a single request through prefill and decode stages"""
...@@ -146,13 +206,27 @@ def main(): ...@@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1) # Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1 prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
# Execute prefill stage # Execute prefill stage
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): # The request id encodes both KV socket addresses so the backend can
continue # shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response # Execute decode stage and stream response
generator = forward_request(DECODE_SERVICE_URL, original_request_data) # Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response response.timeout = None # Disable timeout for streaming response
return response return response
...@@ -168,23 +242,10 @@ def main(): ...@@ -168,23 +242,10 @@ def main():
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting""" """Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task = asyncio.create_task(process_request())
# Enqueue request or reject if queue is full
if not await request_queue.enqueue(task):
return Response(
response=b'{"error": "Server busy, try again later"}',
status=503,
content_type="application/json",
)
try: try:
# Return the response from the processing task return await process_request()
return await task
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle task cancellation (timeout or queue full) logger.warning("Request cancelled")
logger.warning("Request cancelled due to timeout or queue full")
return Response( return Response(
response=b'{"error": "Request cancelled"}', response=b'{"error": "Request cancelled"}',
status=503, status=503,
......
...@@ -24,7 +24,14 @@ cleanup() { ...@@ -24,7 +24,14 @@ cleanup() {
exit 0 exit 0
} }
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
if [[ -z "${VLLM_HOST_IP:-}" ]]; then
export VLLM_HOST_IP=127.0.0.1
echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
else
echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
fi
# install quart first -- required for disagg prefill proxy serve # install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then if python3 -c "import quart" &> /dev/null; then
...@@ -38,7 +45,7 @@ fi ...@@ -38,7 +45,7 @@ fi
wait_for_server() { wait_for_server() {
local port=$1 local port=$1
timeout 1200 bash -c " timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do until curl -i localhost:${port}/v1/models > /dev/null; do
sleep 1 sleep 1
done" && return 0 || return 1 done" && return 0 || return 1
} }
...@@ -48,21 +55,23 @@ wait_for_server() { ...@@ -48,21 +55,23 @@ wait_for_server() {
# prefilling instance, which is the KV producer # prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8100 \ --port 8100 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &
# decoding instance, which is the KV consumer # decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8200 \ --port 8200 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &
# wait until prefill and decode instances are ready # wait until prefill and decode instances are ready
wait_for_server 8100 wait_for_server 8100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment