Unverified Commit 4447c247 authored by davilu-nvidia's avatar davilu-nvidia Committed by GitHub
Browse files

fix: resolve sgl be E/P/D multimodal routing issues (#5500)


Signed-off-by: default avatardavilu <davilu@nvidia.com>
parent 417ce216
......@@ -487,23 +487,13 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
await pd_worker_client.wait_for_instances()
ready_event = asyncio.Event()
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
# Encode Worker is an internal component, should not register with Frontend
# Only needs to provide internal service endpoint for Processor to call
await generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
),
register_llm_with_readiness_gate(
None, # encode worker doesn't have engine
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
......@@ -542,6 +532,17 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event()
try:
if config.serving_mode == DisaggregationMode.DECODE:
# Decode Worker is an internal component, should not register with Frontend
# Only needs to provide internal service endpoint for Processor to call
await generate_endpoint.serve_endpoint(
handler.generate,
metrics_labels=[("model", server_args.served_model_name)],
graceful_shutdown=True,
health_check_payload=health_check_payload,
)
else:
# In aggregated mode, need to register with Frontend
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
......@@ -580,23 +581,15 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
ready_event = asyncio.Event()
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
# Prefill Worker is an internal component, should not register with Frontend
# Only needs to provide internal service endpoint for Decode Worker to call
await generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
......
......@@ -82,7 +82,18 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
normalized_name = normalize_model_name(model_name).lower()
normalized_supported = normalize_model_name(supported_model).lower()
return normalized_name == normalized_supported
# Exact match
if normalized_name == normalized_supported:
return True
# Handle local path case: compare only the model name part (without organization)
# e.g., "qwen2.5-vl-7b-instruct" matches "qwen/qwen2.5-vl-7b-instruct"
if "/" in normalized_supported:
model_part = normalized_supported.split("/")[-1]
if normalized_name == model_part:
return True
return False
def get_qwen_image_features(
......
......@@ -237,7 +237,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
config: Config,
prefill_client: Client = None,
):
super().__init__(component, engine, config, None, prefill_client)
super().__init__(component, engine, config, None)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment