Unverified commit ad5afb7b, authored by nancya-nv and committed by GitHub
Browse files

fix: Add model registration to SGLang multimodal workers fixing bug #4486 (#4512)


Signed-off-by: Nancy Agarwal <nancya@nvidia.com>
Co-authored-by: Kris Hung <krish@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent d5f425ab
...@@ -430,16 +430,24 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con ...@@ -430,16 +430,24 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
await pd_worker_client.wait_for_instances() await pd_worker_client.wait_for_instances()
tasks = [ ready_event = asyncio.Event()
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
)
]
try: try:
await asyncio.gather(*tasks) await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
),
register_llm_with_readiness_gate(
None, # encode worker doesn't have engine
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
readiness_gate=ready_event,
),
)
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
raise raise
...@@ -473,11 +481,24 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -473,11 +481,24 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
await handler.async_init() await handler.async_init()
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
ready_event = asyncio.Event()
try: try:
await generate_endpoint.serve_endpoint( await asyncio.gather(
handler.generate, generate_endpoint.serve_endpoint(
metrics_labels=[("model", server_args.served_model_name)], handler.generate,
graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)],
graceful_shutdown=True,
health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
...@@ -502,6 +523,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co ...@@ -502,6 +523,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
await handler.async_init() await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
ready_event = asyncio.Event()
try: try:
await asyncio.gather( await asyncio.gather(
...@@ -510,7 +532,14 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co ...@@ -510,7 +532,14 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)], metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
) ),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment