Unverified. Commit e5bfcb6a authored by Pan Li, committed by GitHub
Browse files

[BugFix][PD]: make example proxy usable with P2pNcclConnector (#26628)


Signed-off-by: PAN <1162953505@qq.com>
parent 22924383
...@@ -5,11 +5,12 @@ import argparse ...@@ -5,11 +5,12 @@ import argparse
import asyncio import asyncio
import logging import logging
import os import os
import time
import uuid
from urllib.parse import urlparse
import aiohttp import aiohttp
from quart import Quart, Response, make_response, request from quart import Quart, Response, make_response, request
from rate_limiter import RateLimiter
from request_queue import RequestQueue
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -24,26 +25,8 @@ def parse_args(): ...@@ -24,26 +25,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--timeout", "--timeout",
type=float, type=float,
default=300, default=6 * 60 * 60,
help="Timeout for backend service requests in seconds (default: 300)", help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=100,
help="Maximum concurrent requests to backend services (default: 100)",
)
parser.add_argument(
"--queue-size",
type=int,
default=500,
help="Maximum number of requests in the queue (default: 500)",
)
parser.add_argument(
"--rate-limit",
type=int,
default=40,
help="Maximum requests per second (default: 40)",
) )
parser.add_argument( parser.add_argument(
"--port", "--port",
...@@ -54,14 +37,32 @@ def parse_args(): ...@@ -54,14 +37,32 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--prefill-url", "--prefill-url",
type=str, type=str,
default="http://localhost:8100/v1/completions", default="http://localhost:8100",
help="Prefill service endpoint URL", help="Prefill service base URL (protocol + host[:port])",
) )
parser.add_argument( parser.add_argument(
"--decode-url", "--decode-url",
type=str, type=str,
default="http://localhost:8200/v1/completions", default="http://localhost:8200",
help="Decode service endpoint URL", help="Decode service base URL (protocol + host[:port])",
)
parser.add_argument(
"--kv-host",
type=str,
default="localhost",
help="Hostname or IP used by KV transfer (default: localhost)",
)
parser.add_argument(
"--prefill-kv-port",
type=int,
default=14579,
help="Prefill KV port (default: 14579)",
)
parser.add_argument(
"--decode-kv-port",
type=int,
default=14580,
help="Decode KV port (default: 14580)",
) )
return parser.parse_args() return parser.parse_args()
...@@ -73,70 +74,129 @@ def main(): ...@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters # Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
MAX_CONCURRENT_REQUESTS = args.max_concurrent
REQUEST_QUEUE_SIZE = args.queue_size
RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url DECODE_SERVICE_URL = args.decode_url
PORT = args.port PORT = args.port
app = Quart(__name__) PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
# Initialize the rate limiter and request queue logger.info(
rate_limiter = RateLimiter(RATE_LIMIT) "Proxy resolved KV addresses -> prefill: %s, decode: %s",
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE) PREFILL_KV_ADDR,
DECODE_KV_ADDR,
)
app = Quart(__name__)
# Attach the configuration object to the application instance # Attach the configuration object to the application instance so helper
# coroutines can read the resolved backend URLs and timeouts without using
# globals.
app.config.update( app.config.update(
{ {
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"rate_limiter": rate_limiter,
"request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL, "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
"DECODE_KV_ADDR": DECODE_KV_ADDR,
} }
) )
# Start queue processing on app startup def _normalize_base_url(url: str) -> str:
@app.before_serving """Remove any trailing slash so path joins behave predictably."""
async def startup(): return url.rstrip("/")
"""Start request processing task when app starts serving"""
asyncio.create_task(request_queue.process()) def _get_host_port(url: str) -> str:
"""Return the hostname:port portion for logging and KV headers."""
async def forward_request(url, data): parsed = urlparse(url)
"""Forward request to backend service with rate limiting and error handling""" host = parsed.hostname or "localhost"
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} port = parsed.port
if port is None:
# Use rate limiter as context manager port = 80 if parsed.scheme == "http" else 443
async with ( return f"{host}:{port}"
rate_limiter,
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
): DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
try: KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
async with session.post(
url=url, json=data, headers=headers def _build_headers(request_id: str) -> dict[str, str]:
) as response: """Construct the headers expected by vLLM's P2P disagg connector."""
if response.status == 200: headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
# Stream response chunks api_key = os.environ.get("OPENAI_API_KEY")
async for chunk_bytes in response.content.iter_chunked(1024): if api_key:
yield chunk_bytes headers["Authorization"] = f"Bearer {api_key}"
else: return headers
# Handle backend service errors
error_text = await response.text() async def _run_prefill(
logger.error( request_path: str,
"Backend service error: %s - %s", payload: dict,
response.status, headers: dict[str, str],
error_text, request_id: str,
) ):
yield b'{"error": "Backend service error"}' url = f"{PREFILL_BASE}{request_path}"
except aiohttp.ClientError as e: start_ts = time.perf_counter()
# Handle connection errors logger.info("[prefill] start request_id=%s url=%s", request_id, url)
logger.error("Connection error to %s: %s", url, str(e)) try:
yield b'{"error": "Service unavailable"}' async with (
except asyncio.TimeoutError: aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
# Handle timeout errors session.post(url=url, json=payload, headers=headers) as resp,
logger.error("Timeout connecting to %s", url) ):
yield b'{"error": "Service timeout"}' if resp.status != 200:
error_text = await resp.text()
raise RuntimeError(
f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
async def _stream_decode(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{DECODE_BASE}{request_path}"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger.info("[decode] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
logger.error(
"Decode backend error %s - %s", resp.status, error_text
)
err_msg = (
'{"error": "Decode backend error ' + str(resp.status) + '"}'
)
yield err_msg.encode()
return
logger.info(
"[decode] streaming response request_id=%s status=%s",
request_id,
resp.status,
)
async for chunk_bytes in resp.content.iter_chunked(1024):
yield chunk_bytes
logger.info("[decode] finished streaming request_id=%s", request_id)
except asyncio.TimeoutError:
logger.error("Decode service timeout at %s", url)
yield b'{"error": "Decode service timeout"}'
except aiohttp.ClientError as exc:
logger.error("Decode service error at %s: %s", url, exc)
yield b'{"error": "Decode service unavailable"}'
async def process_request(): async def process_request():
"""Process a single request through prefill and decode stages""" """Process a single request through prefill and decode stages"""
...@@ -146,13 +206,27 @@ def main(): ...@@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1) # Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1 prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
# Execute prefill stage # Execute prefill stage
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): # The request id encodes both KV socket addresses so the backend can
continue # shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response # Execute decode stage and stream response
generator = forward_request(DECODE_SERVICE_URL, original_request_data) # Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response response.timeout = None # Disable timeout for streaming response
return response return response
...@@ -168,23 +242,10 @@ def main(): ...@@ -168,23 +242,10 @@ def main():
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting""" """Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task = asyncio.create_task(process_request())
# Enqueue request or reject if queue is full
if not await request_queue.enqueue(task):
return Response(
response=b'{"error": "Server busy, try again later"}',
status=503,
content_type="application/json",
)
try: try:
# Return the response from the processing task return await process_request()
return await task
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle task cancellation (timeout or queue full) logger.warning("Request cancelled")
logger.warning("Request cancelled due to timeout or queue full")
return Response( return Response(
response=b'{"error": "Request cancelled"}', response=b'{"error": "Request cancelled"}',
status=503, status=503,
......
...@@ -24,7 +24,14 @@ cleanup() { ...@@ -24,7 +24,14 @@ cleanup() {
exit 0 exit 0
} }
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
if [[ -z "${VLLM_HOST_IP:-}" ]]; then
export VLLM_HOST_IP=127.0.0.1
echo "Using default VLLM_HOST_IP=127.0.0.1 (override by exporting VLLM_HOST_IP before running this script)"
else
echo "Using provided VLLM_HOST_IP=${VLLM_HOST_IP}"
fi
# install quart first -- required for disagg prefill proxy serve # install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then if python3 -c "import quart" &> /dev/null; then
...@@ -38,7 +45,7 @@ fi ...@@ -38,7 +45,7 @@ fi
wait_for_server() { wait_for_server() {
local port=$1 local port=$1
timeout 1200 bash -c " timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do until curl -i localhost:${port}/v1/models > /dev/null; do
sleep 1 sleep 1
done" && return 0 || return 1 done" && return 0 || return 1
} }
...@@ -48,21 +55,23 @@ wait_for_server() { ...@@ -48,21 +55,23 @@ wait_for_server() {
# prefilling instance, which is the KV producer # prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8100 \ --port 8100 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":"1e9","kv_port":"14579","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8100","send_type":"PUT_ASYNC"}}' &
# decoding instance, which is the KV consumer # decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--host 0.0.0.0 \
--port 8200 \ --port 8200 \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--trust-remote-code \ --trust-remote-code \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":"1e10","kv_port":"14580","kv_connector_extra_config":{"proxy_ip":"'"$VLLM_HOST_IP"'","proxy_port":"30001","http_ip":"'"$VLLM_HOST_IP"'","http_port":"8200","send_type":"PUT_ASYNC"}}' &
# wait until prefill and decode instances are ready # wait until prefill and decode instances are ready
wait_for_server 8100 wait_for_server 8100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment