Unverified commit 81b8af8d, authored by Sean SH Choi and committed by GitHub
Browse files

fix: update router standalone to use updated vLLM API (#4079)


Signed-off-by: Sean Choi <sechoi@nvidia.com>
parent 98842a03
......@@ -35,6 +35,7 @@ from vllm.entrypoints.openai.protocol import (
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import get_tokenizer
from worker import VllmWorkers
......@@ -78,9 +79,11 @@ class ServiceAPI:
or self.http_client is None
):
return ErrorResponse(
message="Service not ready",
type="service_unavailable",
code=503,
error={
"message": "Service not ready",
"type": "service_unavailable",
"code": 503,
},
)
try:
......@@ -95,9 +98,11 @@ class ServiceAPI:
max_tokens_value = request.max_tokens
else:
return ErrorResponse(
message="Either max_tokens or max_completion_tokens must be specified",
type="invalid_request_error",
code=400,
error={
"message": "Either max_tokens or max_completion_tokens must be specified",
"type": "invalid_request_error",
"code": 400,
},
)
# Use vLLM's preprocessing to convert chat to prompt
......@@ -119,9 +124,9 @@ class ServiceAPI:
# Convert request to sampling parameters with our determined max_tokens
sampling_params = request.to_sampling_params(
default_max_tokens=max_tokens_value,
max_tokens=max_tokens_value,
logits_processor_pattern=None,
default_sampling_params=None,
default_sampling_params={},
)
# Get best worker using HTTP request to router
......@@ -129,9 +134,11 @@ class ServiceAPI:
num_tokens = len(tokens)
if num_tokens == 0:
return ErrorResponse(
message="Input prompt is empty",
type="invalid_request_error",
code=400,
error={
"message": "Input prompt is empty",
"type": "invalid_request_error",
"code": 400,
}
)
# It is much preferred to communicate block hashes to the router instead of
......@@ -161,9 +168,11 @@ class ServiceAPI:
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f"Router request failed: {e}")
return ErrorResponse(
message="Router service unavailable",
type="service_unavailable",
code=503,
error={
"message": "Router service unavailable",
"type": "service_unavailable",
"code": 503,
}
)
logger.info(f"Selected worker {best_worker_id} for request")
......@@ -172,9 +181,13 @@ class ServiceAPI:
request_id = f"chatcmpl-{uuid.uuid4()}"
request_metadata = RequestResponseMetadata(request_id=request_id)
# Convert engine_prompt dict to TokensPrompt object
tokens_prompt = TokensPrompt(prompt_token_ids=tokens)
logger.info(f"Created TokensPrompt with {len(tokens)} tokens")
# Get the generator from the selected worker with sampling params
result_generator = self.workers.direct(
engine_prompt, best_worker_id, sampling_params
tokens_prompt, best_worker_id, sampling_params
)
assert request.stream
......@@ -188,6 +201,7 @@ class ServiceAPI:
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
......@@ -195,7 +209,9 @@ class ServiceAPI:
except Exception as e:
logger.error(f"Error processing request: {e}")
return ErrorResponse(message=str(e), type="internal_error", code=500)
return ErrorResponse(
error={"message": str(e), "type": "internal_error", "code": 500}
)
async def initialize_services(self):
"""Initialize workers, HTTP client, and OpenAI serving components"""
......
......@@ -41,7 +41,7 @@ class RouterResponse(BaseModel):
class LoadMetrics(BaseModel):
gpu_cache_usage: float
kv_cache_usage: float
num_waiting_reqs: int
......@@ -101,7 +101,7 @@ class KvRouter:
try:
metrics_dict = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
metrics = LoadMetrics.model_validate(metrics_dict)
self.kv_usages[worker_id] = metrics.gpu_cache_usage
self.kv_usages[worker_id] = metrics.kv_cache_usage
self.waitings[worker_id] = metrics.num_waiting_reqs
except zmq.Again:
pass
......
......@@ -20,7 +20,13 @@ import uuid
from typing import AsyncGenerator, Optional
import zmq
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import (
CacheConfig,
ModelConfig,
ObservabilityConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_events import KVEventsConfig
from vllm.inputs.data import TokensPrompt
from vllm.outputs import RequestOutput
......@@ -50,7 +56,7 @@ class MetricsPublisher(StatLoggerBase):
# Send metrics over ZMQ
metrics_data = {
"num_waiting_reqs": scheduler_stats.num_waiting_reqs,
"gpu_cache_usage": scheduler_stats.gpu_cache_usage,
"kv_cache_usage": scheduler_stats.kv_cache_usage,
}
self.socket.send_json(metrics_data)
......@@ -108,11 +114,14 @@ class VllmWorkers:
scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler"
)
observability_config = ObservabilityConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
kv_events_config=kv_events_config,
scheduler_config=scheduler_config,
observability_config=observability_config,
)
self.llms.append(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment