"tests/fault_tolerance/vscode:/vscode.git/clone" did not exist on "3e0459fb0124c38c46820e0e68089c7c4d50ea28"
Unverified commit dfda6205, authored by hhzhang16 and committed by GitHub
Browse files

fix: sglang -- queue requests until model registration completes (#2701)


Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
parent 6cf96e02
......@@ -81,18 +81,40 @@ async def init(runtime: DistributedRuntime, config: Config):
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
async def gated_generate(request):
    """Stream the handler's responses for `request` once the readiness gate opens."""
    # Park here until ready_event is set; only then hand the request to the handler.
    await ready_event.wait()
    stream = handler.generate(request)
    async for chunk in stream:
        yield chunk
handler = DecodeWorkerHandler(
component, engine, config, publisher, kv_publisher, prefill_client
)
await register_llm_with_runtime_config(
async def register_model():
    """Register the model with the runtime, then open the readiness gate.

    Raises:
        RuntimeError: if registration reports failure; the runtime is asked
            to shut down before the error is raised.
    """
    registered = await register_llm_with_runtime_config(
        engine, generate_endpoint, server_args, dynamo_args.migration_limit
    )
    if registered:
        # Release every request currently waiting on the gate.
        ready_event.set()
        logging.info("Model registration succeeded; processing queued requests")
        return
    logging.error("Model registration failed; shutting down")
    runtime.shutdown()
    raise RuntimeError("Model registration failed")
try:
# TODO: add in native endpoints
# Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
generate_endpoint.serve_endpoint(gated_generate, graceful_shutdown=False),
register_model(),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
......
......@@ -16,8 +16,12 @@ async def register_llm_with_runtime_config(
endpoint: Endpoint,
server_args: ServerArgs,
migration_limit: int,
):
"""Register LLM with runtime config"""
) -> bool:
"""Register LLM with runtime config
Returns:
bool: True if registration succeeded, False if it failed
"""
runtime_config = await _get_runtime_config(engine)
try:
await register_llm(
......@@ -29,9 +33,11 @@ async def register_llm_with_runtime_config(
migration_limit=migration_limit,
runtime_config=runtime_config,
)
logging.info("Successfully registered LLM with runtime config")
return True
except Exception as e:
logging.error(f"Failed to register with runtime config: {e}")
return None
return False
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment