Unverified commit 81b8af8d, authored by Sean SH Choi and committed by GitHub
Browse files

fix: update router standalone to use updated vLLM API (#4079)


Signed-off-by: Sean Choi <sechoi@nvidia.com>
parent 98842a03
...@@ -35,6 +35,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -35,6 +35,7 @@ from vllm.entrypoints.openai.protocol import (
) )
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from worker import VllmWorkers from worker import VllmWorkers
...@@ -78,9 +79,11 @@ class ServiceAPI: ...@@ -78,9 +79,11 @@ class ServiceAPI:
or self.http_client is None or self.http_client is None
): ):
return ErrorResponse( return ErrorResponse(
message="Service not ready", error={
type="service_unavailable", "message": "Service not ready",
code=503, "type": "service_unavailable",
"code": 503,
},
) )
try: try:
...@@ -95,9 +98,11 @@ class ServiceAPI: ...@@ -95,9 +98,11 @@ class ServiceAPI:
max_tokens_value = request.max_tokens max_tokens_value = request.max_tokens
else: else:
return ErrorResponse( return ErrorResponse(
message="Either max_tokens or max_completion_tokens must be specified", error={
type="invalid_request_error", "message": "Either max_tokens or max_completion_tokens must be specified",
code=400, "type": "invalid_request_error",
"code": 400,
},
) )
# Use vLLM's preprocessing to convert chat to prompt # Use vLLM's preprocessing to convert chat to prompt
...@@ -119,9 +124,9 @@ class ServiceAPI: ...@@ -119,9 +124,9 @@ class ServiceAPI:
# Convert request to sampling parameters with our determined max_tokens # Convert request to sampling parameters with our determined max_tokens
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens=max_tokens_value, max_tokens=max_tokens_value,
logits_processor_pattern=None, logits_processor_pattern=None,
default_sampling_params=None, default_sampling_params={},
) )
# Get best worker using HTTP request to router # Get best worker using HTTP request to router
...@@ -129,9 +134,11 @@ class ServiceAPI: ...@@ -129,9 +134,11 @@ class ServiceAPI:
num_tokens = len(tokens) num_tokens = len(tokens)
if num_tokens == 0: if num_tokens == 0:
return ErrorResponse( return ErrorResponse(
message="Input prompt is empty", error={
type="invalid_request_error", "message": "Input prompt is empty",
code=400, "type": "invalid_request_error",
"code": 400,
}
) )
# It is much preferred to communicate block hashes to the router instead of # It is much preferred to communicate block hashes to the router instead of
...@@ -161,9 +168,11 @@ class ServiceAPI: ...@@ -161,9 +168,11 @@ class ServiceAPI:
except (httpx.RequestError, httpx.HTTPStatusError) as e: except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f"Router request failed: {e}") logger.error(f"Router request failed: {e}")
return ErrorResponse( return ErrorResponse(
message="Router service unavailable", error={
type="service_unavailable", "message": "Router service unavailable",
code=503, "type": "service_unavailable",
"code": 503,
}
) )
logger.info(f"Selected worker {best_worker_id} for request") logger.info(f"Selected worker {best_worker_id} for request")
...@@ -172,9 +181,13 @@ class ServiceAPI: ...@@ -172,9 +181,13 @@ class ServiceAPI:
request_id = f"chatcmpl-{uuid.uuid4()}" request_id = f"chatcmpl-{uuid.uuid4()}"
request_metadata = RequestResponseMetadata(request_id=request_id) request_metadata = RequestResponseMetadata(request_id=request_id)
# Convert engine_prompt dict to TokensPrompt object
tokens_prompt = TokensPrompt(prompt_token_ids=tokens)
logger.info(f"Created TokensPrompt with {len(tokens)} tokens")
# Get the generator from the selected worker with sampling params # Get the generator from the selected worker with sampling params
result_generator = self.workers.direct( result_generator = self.workers.direct(
engine_prompt, best_worker_id, sampling_params tokens_prompt, best_worker_id, sampling_params
) )
assert request.stream assert request.stream
...@@ -188,6 +201,7 @@ class ServiceAPI: ...@@ -188,6 +201,7 @@ class ServiceAPI:
conversation, conversation,
self.tokenizer, self.tokenizer,
request_metadata, request_metadata,
enable_force_include_usage=False,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
...@@ -195,7 +209,9 @@ class ServiceAPI: ...@@ -195,7 +209,9 @@ class ServiceAPI:
except Exception as e: except Exception as e:
logger.error(f"Error processing request: {e}") logger.error(f"Error processing request: {e}")
return ErrorResponse(message=str(e), type="internal_error", code=500) return ErrorResponse(
error={"message": str(e), "type": "internal_error", "code": 500}
)
async def initialize_services(self): async def initialize_services(self):
"""Initialize workers, HTTP client, and OpenAI serving components""" """Initialize workers, HTTP client, and OpenAI serving components"""
......
...@@ -41,7 +41,7 @@ class RouterResponse(BaseModel): ...@@ -41,7 +41,7 @@ class RouterResponse(BaseModel):
class LoadMetrics(BaseModel): class LoadMetrics(BaseModel):
gpu_cache_usage: float kv_cache_usage: float
num_waiting_reqs: int num_waiting_reqs: int
...@@ -101,7 +101,7 @@ class KvRouter: ...@@ -101,7 +101,7 @@ class KvRouter:
try: try:
metrics_dict = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK) metrics_dict = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
metrics = LoadMetrics.model_validate(metrics_dict) metrics = LoadMetrics.model_validate(metrics_dict)
self.kv_usages[worker_id] = metrics.gpu_cache_usage self.kv_usages[worker_id] = metrics.kv_cache_usage
self.waitings[worker_id] = metrics.num_waiting_reqs self.waitings[worker_id] = metrics.num_waiting_reqs
except zmq.Again: except zmq.Again:
pass pass
......
...@@ -20,7 +20,13 @@ import uuid ...@@ -20,7 +20,13 @@ import uuid
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
import zmq import zmq
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import (
CacheConfig,
ModelConfig,
ObservabilityConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_events import KVEventsConfig from vllm.distributed.kv_events import KVEventsConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -50,7 +56,7 @@ class MetricsPublisher(StatLoggerBase): ...@@ -50,7 +56,7 @@ class MetricsPublisher(StatLoggerBase):
# Send metrics over ZMQ # Send metrics over ZMQ
metrics_data = { metrics_data = {
"num_waiting_reqs": scheduler_stats.num_waiting_reqs, "num_waiting_reqs": scheduler_stats.num_waiting_reqs,
"gpu_cache_usage": scheduler_stats.gpu_cache_usage, "kv_cache_usage": scheduler_stats.kv_cache_usage,
} }
self.socket.send_json(metrics_data) self.socket.send_json(metrics_data)
...@@ -108,11 +114,14 @@ class VllmWorkers: ...@@ -108,11 +114,14 @@ class VllmWorkers:
scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler" scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler"
) )
observability_config = ObservabilityConfig()
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_events_config=kv_events_config, kv_events_config=kv_events_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
observability_config=observability_config,
) )
self.llms.append( self.llms.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment