Unverified Commit dfda6205 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

fix: sglang -- queue requests until model registration completes (#2701)


Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
parent 6cf96e02
...@@ -81,18 +81,40 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -81,18 +81,40 @@ async def init(runtime: DistributedRuntime, config: Config):
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}") logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
async def gated_generate(request):
"""Queue requests until model registration completes"""
await ready_event.wait() # Block until model is ready
async for response in handler.generate(request):
yield response
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
component, engine, config, publisher, kv_publisher, prefill_client component, engine, config, publisher, kv_publisher, prefill_client
) )
await register_llm_with_runtime_config( async def register_model():
engine, generate_endpoint, server_args, dynamo_args.migration_limit """Register the model and signal readiness"""
) registration_success = await register_llm_with_runtime_config(
engine, generate_endpoint, server_args, dynamo_args.migration_limit
)
if not registration_success:
logging.error("Model registration failed; shutting down")
runtime.shutdown()
raise RuntimeError("Model registration failed")
# Model is ready - allow queued requests to proceed
ready_event.set()
logging.info("Model registration succeeded; processing queued requests")
try: try:
# TODO: add in native endpoints # Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False), generate_endpoint.serve_endpoint(gated_generate, graceful_shutdown=False),
register_model(),
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
......
...@@ -16,8 +16,12 @@ async def register_llm_with_runtime_config( ...@@ -16,8 +16,12 @@ async def register_llm_with_runtime_config(
endpoint: Endpoint, endpoint: Endpoint,
server_args: ServerArgs, server_args: ServerArgs,
migration_limit: int, migration_limit: int,
): ) -> bool:
"""Register LLM with runtime config""" """Register LLM with runtime config
Returns:
bool: True if registration succeeded, False if it failed
"""
runtime_config = await _get_runtime_config(engine) runtime_config = await _get_runtime_config(engine)
try: try:
await register_llm( await register_llm(
...@@ -29,9 +33,11 @@ async def register_llm_with_runtime_config( ...@@ -29,9 +33,11 @@ async def register_llm_with_runtime_config(
migration_limit=migration_limit, migration_limit=migration_limit,
runtime_config=runtime_config, runtime_config=runtime_config,
) )
logging.info("Successfully registered LLM with runtime config")
return True
except Exception as e: except Exception as e:
logging.error(f"Failed to register with runtime config: {e}") logging.error(f"Failed to register with runtime config: {e}")
return None return False
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment